diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6e1e3343ac169..e18fa849e9f30 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation,
                                     broadcastOperands, targetShape, converter);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 18ce8571eeea0..9258442de5a45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
   %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
  // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.muli
+  %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+  return
+}
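
For reference beyond the patch itself: ValueRange::take_front(n) returns the first n operands, so the new helper forwards only {input1, input2} of tosa.mul (excluding the shift) and only {input1} of tosa.negate (excluding the zero points) to the broadcast machinery. Below is a minimal, self-contained C++20 sketch of that selection logic; OpKind and std::span are hypothetical stand-ins for mlir::Operation and mlir::ValueRange, not code from this patch.

#include <cassert>
#include <span>
#include <vector>

enum class OpKind { Mul, Negate, Other };

// Sketch of getBroadcastableOperands with stand-in types: keep only the
// operands that participate in broadcasting.
std::span<const int> getBroadcastableOperands(OpKind kind,
                                              std::span<const int> operands) {
  if (kind == OpKind::Mul)
    return operands.first(2); // shift (operand #2) cannot broadcast
  if (kind == OpKind::Negate)
    return operands.first(1); // input1_zp and output_zp cannot broadcast
  return operands;
}

int main() {
  std::vector<int> mulOperands = {/*input1*/ 0, /*input2*/ 1, /*shift*/ 2};
  assert(getBroadcastableOperands(OpKind::Mul, mulOperands).size() == 2);

  std::vector<int> negOperands = {/*input1*/ 0, /*input1_zp*/ 1,
                                  /*output_zp*/ 2};
  assert(getBroadcastableOperands(OpKind::Negate, negOperands).size() == 1);
  return 0;
}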