Commit 698a0bd

Fix failing lit tests
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0a58d8e commit 698a0bd

2 files changed: +55 -59 lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 45 additions & 48 deletions
@@ -52,50 +52,52 @@ LogicalResult UpcastMXFPOp::verify() {
   }
 
   /// TODO: Temporarily disabled this check to allow for the blocked encoding.
-  /// we need to re-enable this check once we have the dot op encoding
-  /// UpcastMXFPOp lowering
-  // auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  // if (!dotEncoding) {
-  //   return emitOpError("Expected a DotOperandEncodingAttr for values");
-  // }
+  /// Enable once we have the dot op encoding UpcastMXFPOp lowering.
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
+  if (mlir::triton::tools::getBoolEnv(
+          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
+      !dotEncoding) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
     return emitOpError(
         "Expected a BlockOperandEncoding or LinearOperandEncoding "
         "for scales");
   }
+  if (!dotEncoding)
+    return success();
 
-  // if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
-  //   // Necessary to keep all of the scales of a given block of values in the
-  //   // same warp
-  //   auto threadsPerWarp =
-  //       cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
-  //   if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-  //     return emitOpError("Expected threads per warp to be {16, 2}");
-  //   }
-  // }
-
-  // // Change to support fp8 types
-  // const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-  // // Figure out the K dimension for the input A/B. For A/B scale, the K
-  // // dimension is always the last dimension.
-  // const int opIdx = dotEncoding.getOpIdx();
-  // const bool hasBatch = xShape.size() == 3;
-  // const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-
-  // if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
-  //   return emitOpError("K dimension of first operand must be 16 times "
-  //                      "larger than last/K dimension of the second operand");
-  // }
-
-  // // Check other dimensions match too. For input A/B, we need to figure out
-  // the
-  // // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
-  // const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
-  // if (hasBatch && xShape[0] != scaleShape[0])
-  //   return emitOpError("batch dimension must match between operands");
-  // if (xShape[mnIdx] != scaleShape[hasBatch]) {
-  //   return emitOpError("M/N dimension must match between operands");
-  // }
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp =
+        cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
+    }
+  }
+
+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
 
   return success();
 }
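
The shape relation enforced by the re-enabled check above can be reproduced outside MLIR. The following standalone sketch is illustrative only; the helper shapesAreConsistent, its main, and the test shapes are invented here and are not part of the patch. It assumes, as the verifier does, that E2M1 packs two fp4 elements per stored element, so each scale covers 32 / elemsPacked elements along K.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone illustration of the shape arithmetic in verify():
// for opIdx 0 (A) the K dimension is the last non-batch dimension, for
// opIdx 1 (B) it is the first; each scale element covers 32 / elemsPacked
// packed values along K, and the M/N and batch dimensions must match.
bool shapesAreConsistent(const std::vector<int64_t> &xShape,
                         const std::vector<int64_t> &scaleShape,
                         int opIdx, bool isE2M1) {
  const int elemsPacked = isE2M1 ? 2 : 1;
  const bool hasBatch = xShape.size() == 3;
  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back())
    return false;
  if (hasBatch && xShape[0] != scaleShape[0])
    return false;
  // Scale shape is {(batch), M/N, K}, so M/N sits at index 0 or 1.
  return xShape[mnIdx] == scaleShape[hasBatch];
}

int main() {
  // A operand (opIdx = 0), e2m1-packed: [M, K/2] values with [M, K/32] scales.
  assert(shapesAreConsistent({128, 64}, {128, 4}, /*opIdx=*/0, /*isE2M1=*/true));
  return 0;
}
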
@@ -110,8 +112,6 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   auto xShape = xTy.getShape();
 
   auto encoding = xTy.getEncoding();
-  bool upcastMXFPUseDotOpEnc =
-      mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
 
   if (typeEncoded == ScaleDotElemType::E2M1) {
     RankedTensorType retTy;
@@ -122,10 +122,8 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
       retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
     } else {
       Type elemType = FloatType::getBF16(ctx);
-      Attribute newVEncoding;
-      if (upcastMXFPUseDotOpEnc) {
-        auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-
+      Attribute newVEncoding = nullptr;
+      if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
         const int opIdx = oldEncoding.getOpIdx();
         const bool hasBatch = xShape.size() == 3;
         const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
@@ -151,10 +149,9 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
             ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
             oldEncoding.getKWidth() * 2);
        }
-      } else {
-        auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding);
-        assert(oldEncoding &&
-               "Expected a blocked encoding for UpcastMXFP op result.");
+      } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
+        // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
+        assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
        newShape.back() *= 2;
        SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
        sizePerThread.back() *= 2;
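
For the inferReturnTypes changes above, the observable effect on an E2M1 input is that the last (packed K) dimension of the inferred result doubles, since each packed element upcasts to two bf16 values; the diff additionally doubles kWidth for a dot operand encoding and the last sizePerThread entry for a blocked encoding. A minimal sketch of the shape part only, with an invented helper name (inferredUpcastShape), is shown below.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration, not part of the patch: an E2M1 (fp4) input of
// shape [.., Kpacked] upcasts to a bf16 result of shape [.., 2 * Kpacked].
std::vector<int64_t> inferredUpcastShape(std::vector<int64_t> xShape,
                                         bool isE2M1) {
  if (isE2M1)
    xShape.back() *= 2;  // two bf16 values per packed fp4 pair
  return xShape;
}

int main() {
  assert((inferredUpcastShape({128, 64}, /*isE2M1=*/true) ==
          std::vector<int64_t>{128, 128}));
  return 0;
}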

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 10 additions & 11 deletions
@@ -251,7 +251,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
   }
 
 private:
-  bool upcastMXFPUseDotOpEnc =
+  const bool upcastMXFPUseDotOpEnc =
       mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
 
   struct OpDescriptor {
@@ -265,23 +265,22 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
                     triton::gpu::intel::DpasEncodingAttr dpasEnc,
                     RankedTensorType newRetType, ModuleOp mod,
                     PatternRewriter &rewriter) const {
+    assert((aDesc.scale || bDesc.scale) && "No scale provided");
+    assert(!(aDesc.scale && bDesc.scale) && "NYI: Both LHS and RHS scale");
+
     if (aDesc.scale) {
-      assert(bDesc.scale == nullptr && "NYI: both LHS and RHS scale");
       TensorValue newA =
           convertScaledOperand<0>(aDesc, dpasEnc, newRetType, mod, rewriter);
       TensorValue newB =
           convertUnscaledOperand<1>(bDesc, dpasEnc, newRetType, rewriter);
       return {newA, newB};
     }
-    if (bDesc.scale) {
-      assert(aDesc.scale == nullptr && "NYI: both LHS and RHS scale");
-      TensorValue newB =
-          convertScaledOperand<1>(bDesc, dpasEnc, newRetType, mod, rewriter);
-      TensorValue newA =
-          convertUnscaledOperand<0>(aDesc, dpasEnc, newRetType, rewriter);
-      return {newA, newB};
-    }
-    assert(false && "Both LHS and RHS unscaled");
+
+    TensorValue newB =
+        convertScaledOperand<1>(bDesc, dpasEnc, newRetType, mod, rewriter);
+    TensorValue newA =
+        convertUnscaledOperand<0>(aDesc, dpasEnc, newRetType, rewriter);
+    return {newA, newB};
   }
 
   template <unsigned opIdx>
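
The refactored operand conversion above states its precondition up front: exactly one of the two operands carries a scale. That lets the B-scaled path fall through without its own guard and removes the old unreachable "Both LHS and RHS unscaled" assert. The standalone sketch below shows only that control-flow shape; the Desc struct, transformOperands function, and string results are invented here and stand in for the pass's real descriptor and tensor types.

#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for the pass's operand descriptor.
struct Desc {
  bool scale = false;
};

// Mirrors the simplified control flow: assert exactly one scaled operand,
// handle the A-scaled case, otherwise fall through to the B-scaled case.
std::pair<std::string, std::string> transformOperands(const Desc &aDesc,
                                                      const Desc &bDesc) {
  assert((aDesc.scale || bDesc.scale) && "No scale provided");
  assert(!(aDesc.scale && bDesc.scale) && "NYI: Both LHS and RHS scale");
  if (aDesc.scale)
    return {"scaled A", "unscaled B"};
  return {"unscaled A", "scaled B"};
}

int main() {
  Desc a, b;
  a.scale = true;
  auto result = transformOperands(a, b);
  assert(result.first == "scaled A");
  return 0;
}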
