Skip to content

Commit 1c2f01e

Browse files
authored
[MLIR][NVVM] Fixed assertion failure for insufficient parsing validation of nvvm dialect MMAOp (#163432)
1 parent bf5f441 commit 1c2f01e

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,26 @@ LogicalResult MmaOp::verify() {
798798
" attribute");
799799
}
800800

801+
// Validate layout combinations. According to the operation description, most
802+
// MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
803+
// can use other layout combinations.
804+
bool isM8N8K4_F16 =
805+
(mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
806+
getMultiplicandAPtxType() == MMATypes::f16);
807+
808+
if (!isM8N8K4_F16) {
809+
// For all other shapes/types, layoutA must be row and layoutB must be col
810+
if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
811+
return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
812+
"layoutB = #nvvm.mma_layout<col> for shape <")
813+
<< mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
814+
<< "> with element types "
815+
<< stringifyEnum(*getMultiplicandAPtxType()) << " and "
816+
<< stringifyEnum(*getMultiplicandBPtxType())
817+
<< ". Only m8n8k4 with f16 supports other layouts.";
818+
}
819+
}
820+
801821
return success();
802822
}
803823

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,18 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
568568
llvm.return
569569
}
570570

571+
// -----
572+
573+
// Test that ensures invalid row/col layouts for matrices A and B are not accepted
574+
llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
575+
// expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}}
576+
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
577+
{layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>,
578+
multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
579+
intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
580+
shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
581+
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
582+
}
571583

572584
// -----
573585

0 commit comments

Comments
 (0)