Skip to content

Commit 1112516

Browse files
authored
Remove the i32 comparison with 0xff from the DecomposeScaledBlocked pass (#4032)
This change makes the DecomposeScaledBlocked pass use the same logic as the upstream one.
1 parent 8fbd57c commit 1112516

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,31 +148,19 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
148148
DotScaledOp scaledDotOp, ModuleOp mod,
149149
TypedValue<RankedTensorType> mxfp,
150150
TypedValue<RankedTensorType> scale,
151-
FloatType computeType, int dim) const {
151+
int dim) const {
152152
// Implement tl.where(scale == 0xFF, float("nan"), mxfp)
153153
auto loc = scale.getLoc();
154154

155-
// FIXME: use large int type (int32) for comparing with 0xFF to avoid
156-
// accidently masking non-NaN values to NaN.
157-
// This piece of code will be removed after
158-
// https://github.com/intel/intel-xpu-backend-for-triton/issues/3605
159-
FloatType largeFpType = computeType == rewriter.getF16Type()
160-
? rewriter.getF32Type()
161-
: computeType;
162-
int intWidth = largeFpType.getIntOrFloatBitWidth();
163-
auto intType = rewriter.getIntegerType(intWidth);
164-
// Use large int scale type, incase it get nonNaN to NaN
165-
auto scaleTy = scale.getType().clone(intType);
166-
auto zexted = rewriter.create<arith::ExtUIOp>(loc, scaleTy, scale);
167-
168155
// Scale is NaN
156+
auto scaleTy = scale.getType();
169157
auto constFF = rewriter.create<arith::ConstantOp>(
170158
loc, scaleTy,
171159
DenseElementsAttr::get(scaleTy,
172160
APInt(scaleTy.getElementTypeBitWidth(), 0xff)));
173161
auto scaleIsNan = cast<TypedValue<RankedTensorType>>(
174162
rewriter
175-
.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, zexted,
163+
.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, scale,
176164
constFF)
177165
.getResult());
178166
auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim);
@@ -247,7 +235,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
247235
return mxfp;
248236

249237
// 4) If the scale is NaN, return NaN, else return the scaled value.
250-
return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, computeType, kDim);
238+
return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, kDim);
251239
}
252240
};
253241

0 commit comments

Comments
 (0)