@@ -135,11 +135,16 @@ TypedValue<RankedTensorType> DecomposeScaledBlocked::broadcastScale(
 }

 TypedValue<RankedTensorType> DecomposeScaledBlocked::maskNan(
-    PatternRewriter &rewriter, DotScaledOp scaledDotOp, ModuleOp mod,
+    PatternRewriter &rewriter, DotScaledOp scaledDotOp,
     TypedValue<RankedTensorType> mxfp, TypedValue<RankedTensorType> scale,
     int dim) const {
+  // Skip NaN checks if fastMath
+  if (scaledDotOp.getFastMath())
+    return mxfp;
+
   // Implement tl.where(scale == 0xFF, float("nan"), mxfp)
   auto loc = scale.getLoc();
+  auto mod = scaledDotOp->getParentOfType<ModuleOp>();

   // Scale is NaN
   auto scaleTy = scale.getType();
@@ -180,7 +185,6 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   auto fastMath = scaledDotOp.getFastMath();

   auto loc = v.getLoc();
-  auto mod = scaledDotOp->getParentOfType<ModuleOp>();
   auto rank = v.getType().getRank();
   auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

@@ -196,9 +200,33 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   if (!scale)
     return v;

+  // 1) Cast scale to fp16/bf16, broadcast it and convert its layout
+  auto reshapeScale = extendAndBroadcastScale(rewriter, scaledDotOp, scale,
+                                              computeType, v.getType(), opIdx);
+
+  // 2) Multiply
+  auto mxfp = cast<TypedValue<RankedTensorType>>(
+      rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());
+
+  // 3) If the scale is NaN, return NaN, else return the scaled value.
+  return maskNan(rewriter, scaledDotOp, mxfp, scale, kDim);
+}
+
+TypedValue<RankedTensorType> DecomposeScaledBlocked::extendAndBroadcastScale(
+    PatternRewriter &rewriter, DotScaledOp scaledDotOp,
+    TypedValue<RankedTensorType> &scale, FloatType computeType,
+    RankedTensorType dstType, int opIdx) const {
+  auto loc = scale.getLoc();
+  auto mod = scaledDotOp->getParentOfType<ModuleOp>();
+  auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
+  auto rank = v.getType().getRank();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+
   // For some weird reason, we take the scale with shape as if it were coming
   // from the lhs even when it's the rhs. In a normal world, we should accept
-  // this parametre transposed, as we do with the mxfp.
+  // this parameter transposed, as we do with the mxfp.
+  //
+  // Notice: this is an in-place change.
   if (opIdx == 1) {
     auto order = getTransposeOrder(rank);
     scale = rewriter.create<TransOp>(loc, scale, order);
@@ -207,21 +235,9 @@ DecomposeScaledBlocked::scaleArg(PatternRewriter &rewriter,
   // 1) Cast scale to compute type (fp16/bf16)
   auto scale16 = scaleTo16(rewriter, scale, computeType);

-  // 2) Broadcast scale to the same shape and layout as v
+  // 2) Broadcast scale to the same shape as v and convert the layout
   auto reshapeScale = broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim);
-  reshapeScale =
-      rewriter.create<ConvertLayoutOp>(loc, v.getType(), reshapeScale);
-
-  // 3) Multiply
-  auto mxfp = cast<TypedValue<RankedTensorType>>(
-      rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());
-
-  // Skip NaN checks if fastMath
-  if (fastMath)
-    return mxfp;
-
-  // 4) If the scale is NaN, return NaN, else return the scaled value.
-  return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, kDim);
+  return rewriter.create<ConvertLayoutOp>(loc, dstType, reshapeScale);
 }

 TypedValue<RankedTensorType>
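
For readability, here is roughly how the scaleArg body reads once this diff is applied, stitched together from the context and '+' lines above. The function signature and the code that produces v and scale sit outside these hunks, so they are assumed from context rather than shown.

  // Consolidated "after" view of the scaleArg body (sketch only; taken from
  // the hunks above, not from the full file).
  auto loc = v.getLoc();
  auto rank = v.getType().getRank();
  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

  if (!scale)
    return v;

  // 1) Cast scale to fp16/bf16, broadcast it and convert its layout.
  auto reshapeScale = extendAndBroadcastScale(rewriter, scaledDotOp, scale,
                                              computeType, v.getType(), opIdx);

  // 2) Multiply.
  auto mxfp = cast<TypedValue<RankedTensorType>>(
      rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());

  // 3) maskNan now owns the fastMath early-out: it returns mxfp untouched
  //    when fastMath is set, otherwise NaN wherever the scale was NaN.
  return maskNan(rewriter, scaledDotOp, mxfp, scale, kDim);

The net effect is that the NaN masking (including the fastMath early-out) is encapsulated in maskNan, while the scale extension, broadcast, and layout conversion move into the new extendAndBroadcastScale helper, which looks up the ModuleOp from scaledDotOp instead of taking it as a parameter.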