Skip to content

Commit c9466d4

Browse files
Fix lit failures after 96e53bb
Signed-off-by: Whitney Tsang <[email protected]>
1 parent edfbc64 commit c9466d4

File tree

2 files changed: +7 −7 lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 2 deletions

@@ -1396,6 +1396,8 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
 LogicalResult
 SliceEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           unsigned dim, DistributedEncodingTrait parent) {
+  if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
+    return success();
   unsigned rank = cast<LayoutEncodingTrait>(parent).getRank();
   if (rank <= 1)
     return emitError() << "parent layout must have at least rank >= 2";
@@ -2141,9 +2143,11 @@ LogicalResult DotOperandEncodingAttr::verify(

   if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     if (kWidth != 16 && parentAttr.getVersion() == 1 ||
-        kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
+        kWidth != 4 && kWidth != 8 && kWidth != 16 &&
+            parentAttr.getVersion() == 2)
       return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
-                            "gfx11 and 8/16 for gfx12";
+                            "gfx11 and 4/8/16 for gfx12 (including packed "
+                            "cases for `scaled_dot`)";
     return success();
   }


lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 1 addition & 5 deletions

@@ -502,11 +502,6 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
                                        MemDescType memdesc, StringRef regName) {
   if (type.getRank() != 2)
     return op->emitOpError(regName) << " must be a 2D tensor";
-  if (isa<TensorMemoryScalesEncodingAttr>(memdesc.getEncoding()) &&
-      !type.getElementType().isInteger(8)) {
-    return op->emitOpError(regName)
-        << " expected to be a tensor of i8 for MMA scales encoding";
-  }
   if (type.getEncoding()) {
     auto enc = dyn_cast<DistributedEncodingTrait>(type.getEncoding());
     if (!enc) {
@@ -526,6 +521,7 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
                          << " layout is not TMEM compatible";
       for (Attribute layout : layouts)
         diag.attachNote() << "potential TMEM layout: " << layout;
+      return diag;
     }
   }
   return success();

0 commit comments

Comments (0)