From 25235355257bf0fcc4b7c51ada9c7e63117bd330 Mon Sep 17 00:00:00 2001 From: vvish Date: Sun, 9 Nov 2025 15:58:48 +0100 Subject: [PATCH 1/2] Added stablehlo divide fp to tosa reciprocal+mul conversion --- stablehlo/conversions/tosa/tests/binary.mlir | 9 +++++ .../transforms/StablehloLegalizeToTosa.cpp | 35 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/stablehlo/conversions/tosa/tests/binary.mlir b/stablehlo/conversions/tosa/tests/binary.mlir index 71c4392adb..c1f26771ee 100644 --- a/stablehlo/conversions/tosa/tests/binary.mlir +++ b/stablehlo/conversions/tosa/tests/binary.mlir @@ -50,6 +50,15 @@ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi return %0 : tensor<10xi32> } +// CHECK-LABEL: @divide_f32 +func.func @divide_f32(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> { + // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>} + // CHECK-DAG: %[[VAR1:.*]] = tosa.reciprocal %arg1 + // CHECK: tosa.mul %arg0, %[[VAR1]], %[[VAR0]] + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> +} + // CHECK-LABEL: @dot_vector_vector func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor { // CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index 59211515f8..c2492d08fc 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -523,6 +523,39 @@ struct ConvertStablehloReshapeOp } }; +struct ConvertStablehloFloatDivideOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::DivOp op, + PatternRewriter& rewriter) const override { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + if (!lhsType || !rhsType) { + return rewriter.notifyMatchFailure(op, "expected ranked tensor types"); + } + + if (!llvm::isa(lhsType.getElementType()) && + !llvm::isa(rhsType.getElementType())) { + return rewriter.notifyMatchFailure( + op, "only converts floating point division"); + } + + auto shiftTensorType = RankedTensorType::get({1}, rewriter.getI8Type()); + auto zeroShiftValue = DenseElementsAttr::get( + shiftTensorType, rewriter.getIntegerAttr(rewriter.getI8Type(), 0)); + auto shiftConst = rewriter.create( + op.getLoc(), shiftTensorType, zeroShiftValue); + + auto reciprocalOp = + rewriter.create(op.getLoc(), rhsType, op.getRhs()); + auto mulOp = rewriter.create( + op.getLoc(), op.getType(), op.getLhs(), reciprocalOp, shiftConst); + rewriter.replaceOp(op, mulOp.getResult()); + return success(); + } +}; + LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) { RewritePatternSet patternList(ctx); populateGeneratedPDLLPatterns(patternList); @@ -543,6 +576,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) { patternList.addWithLabel({"StablehloWhile"}, ctx); patternList.addWithLabel({"StablehloReshape"}, ctx); + patternList.addWithLabel( + {"StablehloFloatDivide"}, ctx); patterns = std::move(patternList); return success(); From 46d0634047bd83bc8df3ae2eeb581feeeb1c5bd5 Mon Sep 17 00:00:00 2001 From: vvish Date: Sat, 15 Nov 2025 00:37:22 +0100 Subject: [PATCH 2/2] Use Op::create instead of deprecated Builder::create --- .../tosa/transforms/StablehloLegalizeToTosa.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index c2492d08fc..be960db6b5 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -544,13 +544,15 @@ struct ConvertStablehloFloatDivideOp auto shiftTensorType = RankedTensorType::get({1}, rewriter.getI8Type()); auto zeroShiftValue = DenseElementsAttr::get( shiftTensorType, rewriter.getIntegerAttr(rewriter.getI8Type(), 0)); - auto shiftConst = rewriter.create( - op.getLoc(), shiftTensorType, zeroShiftValue); + auto shiftConst = tosa::ConstOp::create(rewriter, op.getLoc(), + shiftTensorType, zeroShiftValue); auto reciprocalOp = - rewriter.create(op.getLoc(), rhsType, op.getRhs()); - auto mulOp = rewriter.create( - op.getLoc(), op.getType(), op.getLhs(), reciprocalOp, shiftConst); + tosa::ReciprocalOp::create(rewriter, op.getLoc(), rhsType, op.getRhs()); + + auto mulOp = tosa::MulOp::create(rewriter, op.getLoc(), op.getType(), + op.getLhs(), reciprocalOp, shiftConst); + rewriter.replaceOp(op, mulOp.getResult()); return success(); }