Skip to content

Commit 24f5385

Browse files
authored
[MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops (llvm#148783)
Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. In particular, the intrinsics that generate the new ldmatrix variants introduced with Blackwell were not accessible through the NVVM ops. This commit extends the ldmatrix operation to support all available ldmatrix intrinsics.
1 parent e1a694c commit 24f5385

File tree

9 files changed

+287
-86
lines changed

9 files changed

+287
-86
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,13 +2070,16 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
20702070

20712071
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
20722072
Results<(outs AnyType:$res)>,
2073-
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
2073+
Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
2074+
MMALayoutAttr:$layout,
2075+
LdStMatrixShapeAttr:$shape,
2076+
LdStMatrixEltTypeAttr:$eltType)> {
20742077

20752078
let summary = "cooperative matrix load";
20762079

20772080
string llvmBuilder = [{
20782081
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
2079-
auto intId = getLdMatrixIntrinsicId($layout, $num);
2082+
auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType);
20802083
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
20812084
}];
20822085

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
283283
Value srcPtr =
284284
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
285285
adaptor.getSrcMemref(), adaptor.getIndices());
286+
auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
286287
Value ldMatrixResult = NVVM::LdMatrixOp::create(
287288
b, ldMatrixResultType, srcPtr,
288289
/*num=*/op.getNumTiles(),
289290
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
290-
: NVVM::MMALayout::row);
291+
: NVVM::MMALayout::row,
292+
/*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
291293

292294
// The ldmatrix operation returns either a single i32 value or a struct of
293295
// i32 values. Here we unpack those values and cast them back to their

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -811,24 +811,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
811811
}
812812

813813
LogicalResult NVVM::LdMatrixOp::verify() {
814-
unsigned addressSpace =
815-
llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
816-
if (addressSpace != NVVM::kSharedMemorySpace)
817-
return emitOpError("expected source pointer in memory space 3");
818-
819-
if (getNum() != 1 && getNum() != 2 && getNum() != 4)
820-
return emitOpError("expected num attribute to be 1, 2 or 4");
814+
uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
815+
if (m == 8 && n == 8) {
816+
if (num != 1 && num != 2 && num != 4) {
817+
return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
818+
"matrix");
819+
}
820+
if (getEltType() != LdStMatrixEltType::B16) {
821+
return emitOpError("expected element type to be b16 for 8x8 matrix");
822+
}
823+
} else if (m == 8 && n == 16) {
824+
if (num != 1 && num != 2 && num != 4) {
825+
return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
826+
"matrix");
827+
}
828+
if (getLayout() != MMALayout::row) {
829+
return emitOpError("expected layout to be row for 8x16 matrix");
830+
}
831+
if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
832+
getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
833+
return emitOpError("expected element type to be b8x16.b4x16_p64 or "
834+
"b8x16.b6x16_p32 for 8x16 matrix");
835+
}
836+
} else if (m == 16 && n == 16) {
837+
if (num != 1 && num != 2) {
838+
return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
839+
"matrix");
840+
}
841+
if (getLayout() != MMALayout::col) {
842+
return emitOpError("expected layout to be col for 16x16 matrix");
843+
}
844+
if (getEltType() != LdStMatrixEltType::B8 &&
845+
getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
846+
getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
847+
return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
848+
"b8x16.b6x16_p32 for 16x16 matrix");
849+
}
850+
} else {
851+
return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
852+
}
821853

822854
Type i32 = IntegerType::get(getContext(), 32);
823-
if (getNum() == 1 && getType() != i32)
855+
uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
856+
if (numElements == 1 && getType() != i32)
824857
return emitOpError("expected destination type is i32");
825-
if (getNum() == 2 || getNum() == 4) {
858+
if (numElements == 2 || numElements == 4) {
826859
Type dstType = LLVM::LLVMStructType::getLiteral(
827-
getContext(), SmallVector<Type>(getNum(), i32));
860+
getContext(), SmallVector<Type>(numElements, i32));
828861
if (getType() != dstType)
829862
return emitOpError("expected destination type is a structure of ")
830-
<< getNum() << " elements of type i32";
863+
<< numElements << " elements of type i32";
831864
}
865+
832866
return success();
833867
}
834868

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
135135
llvm_unreachable("unsupported vote kind");
136136
}
137137

138-
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
139-
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
140-
int32_t num) {
141-
if (layout == NVVM::MMALayout::row) {
138+
static llvm::Intrinsic::ID
139+
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
140+
NVVM::LdStMatrixShapeAttr shape,
141+
NVVM::LdStMatrixEltType eltType) {
142+
if (shape.getM() == 8 && shape.getN() == 8) {
142143
switch (num) {
143144
case 1:
144-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
145+
return (layout == NVVM::MMALayout::row)
146+
? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
147+
: llvm::Intrinsic::
148+
nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
145149
case 2:
146-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
150+
return (layout == NVVM::MMALayout::row)
151+
? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
152+
: llvm::Intrinsic::
153+
nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
147154
case 4:
148-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
149-
default:
150-
llvm_unreachable("unsupported number of matrix");
155+
return (layout == NVVM::MMALayout::row)
156+
? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
157+
: llvm::Intrinsic::
158+
nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
151159
}
152-
153-
} else {
154-
switch (num) {
155-
case 1:
156-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
157-
case 2:
158-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
159-
case 4:
160-
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
161-
default:
162-
llvm_unreachable("unsupported number of matrix");
160+
} else if (shape.getM() == 8 && shape.getN() == 16) {
161+
if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
162+
switch (num) {
163+
case 1:
164+
return llvm::Intrinsic::
165+
nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
166+
case 2:
167+
return llvm::Intrinsic::
168+
nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
169+
case 4:
170+
return llvm::Intrinsic::
171+
nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
172+
}
173+
} else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
174+
switch (num) {
175+
case 1:
176+
return llvm::Intrinsic::
177+
nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
178+
case 2:
179+
return llvm::Intrinsic::
180+
nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
181+
case 4:
182+
return llvm::Intrinsic::
183+
nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
184+
}
185+
}
186+
} else if (shape.getM() == 16 && shape.getN() == 16) {
187+
if (eltType == NVVM::LdStMatrixEltType::B8) {
188+
switch (num) {
189+
case 1:
190+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
191+
case 2:
192+
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
193+
}
194+
} else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
195+
switch (num) {
196+
case 1:
197+
return llvm::Intrinsic::
198+
nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
199+
case 2:
200+
return llvm::Intrinsic::
201+
nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
202+
}
203+
} else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
204+
switch (num) {
205+
case 1:
206+
return llvm::Intrinsic::
207+
nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
208+
case 2:
209+
return llvm::Intrinsic::
210+
nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
211+
}
163212
}
164213
}
214+
llvm_unreachable("unknown ldmatrix kind");
165215
}
166216

167217
/// Return the intrinsic ID associated with stmatrix for the given parameters.

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
159159
// CHECK-LABEL: @ldmatrix_x4
160160
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
161161
%c0 = arith.constant 0 : index
162-
// CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
162+
// CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
163163
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
164164
// CHECK: llvm.extractvalue
165165
// CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
179179
// CHECK-LABEL: @ldmatrix_x1
180180
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
181181
%c0 = arith.constant 0 : index
182-
// CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
182+
// CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
183183
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
184184
// CHECK: llvm.bitcast
185185
// CHECK: llvm.insertvalue

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,38 +1220,6 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
12201220

12211221
// -----
12221222

1223-
llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
1224-
// expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
1225-
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
1226-
llvm.return
1227-
}
1228-
1229-
// -----
1230-
1231-
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
1232-
// expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
1233-
%l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
1234-
llvm.return
1235-
}
1236-
1237-
// -----
1238-
1239-
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
1240-
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
1241-
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
1242-
llvm.return
1243-
}
1244-
1245-
// -----
1246-
1247-
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
1248-
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
1249-
%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
1250-
llvm.return
1251-
}
1252-
1253-
// -----
1254-
12551223
llvm.func @caller() {
12561224
// expected-error @below {{expected function call to produce a value}}
12571225
llvm.call @callee() : () -> ()

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
385385
llvm.return
386386
}
387387

388-
// CHECK-LABEL: llvm.func @ld_matrix
389-
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
390-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
391-
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
392-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
393-
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
394-
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
395-
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
396-
llvm.return
397-
}
398-
399388
// CHECK-LABEL: llvm.func @redux_sync
400389
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
401390
// CHECK: nvvm.redux.sync add %{{.*}}

0 commit comments

Comments
 (0)