diff --git a/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/conversions/tosa/tests/binary.mlir index 71c4392adb..c1f26771ee 100644 --- a/stablehlo/conversions/tosa/tests/binary.mlir +++ b/stablehlo/conversions/tosa/tests/binary.mlir @@ -50,6 +50,15 @@ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi return %0 : tensor<10xi32> } +// CHECK-LABEL: @divide_f32 +func.func @divide_f32(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> { + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>} + // CHECK-DAG: %[[VAR1:.*]] = tosa.reciprocal %arg1 + // CHECK: tosa.mul %arg0, %[[VAR1]], %[[VAR0]] + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> +} + // CHECK-LABEL: @dot_vector_vector func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor { // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index 59211515f8..be960db6b5 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -523,6 +523,41 @@ struct ConvertStablehloReshapeOp } }; +struct ConvertStablehloFloatDivideOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::DivOp op, + PatternRewriter& rewriter) const override { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + if (!lhsType || !rhsType) { + return rewriter.notifyMatchFailure(op, "expected ranked tensor types"); + } + + if (!llvm::isa(lhsType.getElementType()) && + !llvm::isa(rhsType.getElementType())) { + return rewriter.notifyMatchFailure( + op, "only converts floating point division"); + } + + auto shiftTensorType = RankedTensorType::get({1}, rewriter.getI8Type()); + auto zeroShiftValue = DenseElementsAttr::get( + shiftTensorType, rewriter.getIntegerAttr(rewriter.getI8Type(), 0)); + auto shiftConst = tosa::ConstOp::create(rewriter, op.getLoc(), + shiftTensorType, zeroShiftValue); + + auto reciprocalOp = + tosa::ReciprocalOp::create(rewriter, op.getLoc(), rhsType, op.getRhs()); + + auto mulOp = tosa::MulOp::create(rewriter, op.getLoc(), op.getType(), + op.getLhs(), reciprocalOp, shiftConst); + + rewriter.replaceOp(op, mulOp.getResult()); + return success(); + } +}; + LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) { RewritePatternSet patternList(ctx); populateGeneratedPDLLPatterns(patternList); @@ -543,6 +578,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) { patternList.addWithLabel({"StablehloWhile"}, ctx); patternList.addWithLabel({"StablehloReshape"}, ctx); + patternList.addWithLabel( + {"StablehloFloatDivide"}, ctx); patterns = std::move(patternList); return success();