@@ -148,31 +148,19 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
148148 DotScaledOp scaledDotOp, ModuleOp mod,
149149 TypedValue<RankedTensorType> mxfp,
150150 TypedValue<RankedTensorType> scale,
151- FloatType computeType, int dim) const {
151+ int dim) const {
152152 // Implement tl.where(scale == 0xFF, float("nan"), mxfp)
153153 auto loc = scale.getLoc ();
154154
155- // FIXME: use large int type (int32) for comparing with 0xFF to avoid
156- // accidently masking non-NaN values to NaN.
157- // This piece of code will be removed after
158- // https://github.com/intel/intel-xpu-backend-for-triton/issues/3605
159- FloatType largeFpType = computeType == rewriter.getF16Type ()
160- ? rewriter.getF32Type ()
161- : computeType;
162- int intWidth = largeFpType.getIntOrFloatBitWidth ();
163- auto intType = rewriter.getIntegerType (intWidth);
164- // Use large int scale type, incase it get nonNaN to NaN
165- auto scaleTy = scale.getType ().clone (intType);
166- auto zexted = rewriter.create <arith::ExtUIOp>(loc, scaleTy, scale);
167-
168155 // Scale is NaN
156+ auto scaleTy = scale.getType ();
169157 auto constFF = rewriter.create <arith::ConstantOp>(
170158 loc, scaleTy,
171159 DenseElementsAttr::get (scaleTy,
172160 APInt (scaleTy.getElementTypeBitWidth (), 0xff )));
173161 auto scaleIsNan = cast<TypedValue<RankedTensorType>>(
174162 rewriter
175- .create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq, zexted ,
163+ .create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq, scale ,
176164 constFF)
177165 .getResult ());
178166 auto cond = broadcastScale (rewriter, scaledDotOp, mod, scaleIsNan, dim);
@@ -247,7 +235,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
247235 return mxfp;
248236
249237 // 4) If the scale is NaN, return NaN, else return the scaled value.
250- return maskNan (rewriter, scaledDotOp, mod, mxfp, scale, computeType, kDim );
238+ return maskNan (rewriter, scaledDotOp, mod, mxfp, scale, kDim );
251239 }
252240};
253241
0 commit comments