Skip to content

Commit b3e8d2f

Browse files
committed
Fixed an assertion failure caused by insufficient validation when parsing NVVM dialect MmaOp operations
1 parent b358af1 commit b3e8d2f

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,26 @@ LogicalResult MmaOp::verify() {
787787
" attribute");
788788
}
789789

790+
// Validate layout combinations. According to the operation description, most
791+
// MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
792+
// can use other layout combinations.
793+
bool isM8N8K4_F16 =
794+
(mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
795+
getMultiplicandAPtxType() == MMATypes::f16);
796+
797+
if (!isM8N8K4_F16) {
798+
// For all other shapes/types, layoutA must be row and layoutB must be col
799+
if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
800+
return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
801+
"layoutB = #nvvm.mma_layout<col> for shape <")
802+
<< mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
803+
<< "> with element types "
804+
<< stringifyEnum(*getMultiplicandAPtxType()) << " and "
805+
<< stringifyEnum(*getMultiplicandBPtxType())
806+
<< ". Only m8n8k4 with f16 supports other layouts.";
807+
}
808+
}
809+
790810
return success();
791811
}
792812

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,16 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
559559
%res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
560560
llvm.return
561561
}
562+
563+
// -----
564+
565+
// Test that ensures invalid row/col layouts for matrices A and B are not accepted
// This case uses shape m16n8k32 with s4 multiplicands and sets layoutA to
// `col` (layoutB is `col`). Because the shape/type is not m8n8k4 with f16,
// the MmaOp verifier requires layoutA = row and layoutB = col, so the op
// must be rejected. The expected-error directive below quotes only the final
// sentence of the verifier's diagnostic and must stay directly above the op.
566+
llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
567+
// expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}}
568+
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
569+
{layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>,
570+
multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
571+
intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
572+
shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
573+
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
574+
}

0 commit comments

Comments
 (0)