@@ -523,6 +523,39 @@ struct ConvertStablehloReshapeOp
523523 }
524524};
525525
526+ struct ConvertStablehloFloatDivideOp
527+ : public OpRewritePattern<stablehlo::DivOp> {
528+ using OpRewritePattern<stablehlo::DivOp>::OpRewritePattern;
529+
530+ LogicalResult matchAndRewrite (stablehlo::DivOp op,
531+ PatternRewriter& rewriter) const override {
532+ auto lhsType = dyn_cast<RankedTensorType>(op.getLhs ().getType ());
533+ auto rhsType = dyn_cast<RankedTensorType>(op.getRhs ().getType ());
534+ if (!lhsType || !rhsType) {
535+ return rewriter.notifyMatchFailure (op, " expected ranked tensor types" );
536+ }
537+
538+ if (!llvm::isa<mlir::FloatType>(lhsType.getElementType ()) &&
539+ !llvm::isa<mlir::FloatType>(rhsType.getElementType ())) {
540+ return rewriter.notifyMatchFailure (
541+ op, " only converts floating point division" );
542+ }
543+
544+ auto shiftTensorType = RankedTensorType::get ({1 }, rewriter.getI8Type ());
545+ auto zeroShiftValue = DenseElementsAttr::get (
546+ shiftTensorType, rewriter.getIntegerAttr (rewriter.getI8Type (), 0 ));
547+ auto shiftConst = rewriter.create <tosa::ConstOp>(
548+ op.getLoc (), shiftTensorType, zeroShiftValue);
549+
550+ auto reciprocalOp =
551+ rewriter.create <tosa::ReciprocalOp>(op.getLoc (), rhsType, op.getRhs ());
552+ auto mulOp = rewriter.create <tosa::MulOp>(
553+ op.getLoc (), op.getType (), op.getLhs (), reciprocalOp, shiftConst);
554+ rewriter.replaceOp (op, mulOp.getResult ());
555+ return success ();
556+ }
557+ };
558+
526559LogicalResult StablehloLegalizeToTosaPass::initialize (MLIRContext* ctx) {
527560 RewritePatternSet patternList (ctx);
528561 populateGeneratedPDLLPatterns (patternList);
@@ -543,6 +576,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
543576 patternList.addWithLabel <ConvertStablehloWhileOp>({" StablehloWhile" }, ctx);
544577 patternList.addWithLabel <ConvertStablehloReshapeOp>({" StablehloReshape" },
545578 ctx);
579+ patternList.addWithLabel <ConvertStablehloFloatDivideOp>(
580+ {" StablehloFloatDivide" }, ctx);
546581
547582 patterns = std::move (patternList);
548583 return success ();
0 commit comments