@@ -398,6 +398,7 @@ class DecomposeScaledBlocked
398398 auto scale = scaledDotOp.getLhsScale ();
399399 auto aType = scaledDotOp.getLhsType ();
400400 auto bType = scaledDotOp.getRhsType ();
401+ bool fastMath = scaledDotOp.getFastMath ();
401402
402403 auto rank = oldRetType.getShape ().size ();
403404 if (rank != 2 )
@@ -510,15 +511,17 @@ class DecomposeScaledBlocked
510511 newScaleEncoding = LinearEncodingAttr::get (ctx, std::move (newLL));
511512 }
512513
513- a = createArg (rewriter, a, 0 , aType, newAEncoding, scale, newScaleEncoding);
514+ a = createArg (rewriter, a, 0 , aType, newAEncoding, scale, newScaleEncoding,
515+ fastMath);
514516
515517 Operation *newDot = nullptr ;
516518 if (versionMajor == 2 ) {
517519 // Upcast B operand
518520 assert (bType != ScaleDotElemType::E2M1 && " NYI: rhs scale for fp4" );
519521 auto newBEncoding = DotOperandEncodingAttr::get (ctx, 1 , mmaEnc, bKWidth);
520522 b = createArg (rewriter, b, 1 , bType, newBEncoding,
521- /* scale=*/ std::nullopt , /* scaleEncoding=*/ std::nullopt );
523+ /* scale=*/ std::nullopt , /* scaleEncoding=*/ std::nullopt ,
524+ fastMath);
522525 newDot = rewriter.create <DotOp>(scaledDotOp.getLoc (), newRetType, a, b,
523526 newAcc);
524527 } else {
@@ -541,7 +544,7 @@ class DecomposeScaledBlocked
541544 createArg (mlir::PatternRewriter &rewriter, TypedValue<RankedTensorType> v,
542545 int idx, ScaleDotElemType type, std::optional<Attribute> vEncoding,
543546 std::optional<TypedValue<RankedTensorType>> opt_scale,
544- std::optional<Attribute> scaleEncoding) const {
547+ std::optional<Attribute> scaleEncoding, bool fastMath ) const {
545548 auto ctx = rewriter.getContext ();
546549 // Create a new tensor with a given encoding or remove the encoding
547550 auto maybeWithEncoding =
@@ -576,7 +579,7 @@ class DecomposeScaledBlocked
576579 auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType (
577580 ret, type, Builder (v.getContext ()).getBF16Type ());
578581 ret = rewriter.create <triton::gpu::UpcastMXFPOp>(v.getLoc (), retTy, ret,
579- scale, type);
582+ scale, type, fastMath );
580583 }
581584 return ret;
582585 }
@@ -589,6 +592,7 @@ class DecomposeScaledBlocked
589592 auto scale = scaledDotOp.getLhsScale ();
590593 auto aType = scaledDotOp.getLhsType ();
591594 auto bType = scaledDotOp.getRhsType ();
595+ bool fastMath = scaledDotOp.getFastMath ();
592596
593597 // create a DotOp to be passed in to getMMAVersionSafe
594598 // We don't pass encodings as we just want to get the type and shape
@@ -597,15 +601,16 @@ class DecomposeScaledBlocked
597601 // end up in the graph
598602 RankedTensorType aTType =
599603 createArg (rewriter, a, 0 , aType, /* vEncoding=*/ std::nullopt , scale,
600- /* scaleEncoding=*/ std::nullopt )
604+ /* scaleEncoding=*/ std::nullopt , fastMath )
601605 .getType ();
602606 auto aTypeNoEnc =
603607 RankedTensorType::get (aTType.getShape (), aTType.getElementType ());
604608 a = rewriter.create <ConvertLayoutOp>(scaledDotOp.getLoc (), aTypeNoEnc, a);
605609
606610 RankedTensorType bTType =
607611 createArg (rewriter, b, 1 , bType, /* vEncoding=*/ std::nullopt ,
608- /* scale=*/ std::nullopt , /* scaleEncoding=*/ std::nullopt )
612+ /* scale=*/ std::nullopt , /* scaleEncoding=*/ std::nullopt ,
613+ fastMath)
609614 .getType ();
610615 auto bTypeNoEnc =
611616 RankedTensorType::get (bTType.getShape (), bTType.getElementType ());
@@ -752,7 +757,7 @@ static Operation *transposeDotOp(DotScaledOp dotOp) {
752757 Value result = builder.create <DotScaledOp>(
753758 dotOp.getLoc (), cTransposed.getType (), rhsTransposed, lhsTransposed,
754759 cTransposed, dotOp.getRhsScale (), dotOp.getLhsScale (), dotOp.getRhsType (),
755- dotOp.getLhsType ());
760+ dotOp.getLhsType (), dotOp. getFastMath () );
756761 Operation *transposedResult =
757762 builder.create <TransOp>(result.getLoc (), result, transOrder);
758763 dotOp.replaceAllUsesWith (transposedResult);
0 commit comments