Commit 49995b2

[MLIR][GPU] subgroup_mma fp64 extension (#165873)
This PR extends the `gpu.subgroup_mma_*` ops to support the fp64 type. The extension requires special handling during the lowering to `nvvm` because of the return type of the load ops for the a and b fragments: they return a scalar instead of a struct.
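To illustrate the extension, a minimal sketch (assuming the m8n8k4 f64 WMMA configuration; operand shapes, memrefs, and value names mirror the integration test added at the bottom of this diff):

```mlir
// f64 fragments use the 8x4 (A), 4x8 (B), and 8x8 (C) shapes of m8n8k4.
%A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index}
       : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
%B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index}
       : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
%C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index}
       : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
%D = gpu.subgroup_mma_compute %A, %B, %C
       : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp">
         -> !gpu.mma_matrix<8x8xf64, "COp">
```

When lowered to `nvvm`, the A and B fragments above become plain `f64` values (one element per lane) rather than LLVM structs, which is the asymmetry the conversion changes below account for.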
1 parent 31a552d · commit 49995b2

File tree: 8 files changed (+146, -20 lines)

mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ class MMAMatrixType;
 #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
 #include "mlir/Conversion/Passes.h.inc"
 
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+Type convertMMAToLLVMType(gpu::MMAMatrixType type);
 
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);

mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Lines changed: 4 additions & 4 deletions

@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
    ```
  }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
    ```
  }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
-                   Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
-                   Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
+                   Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
+                   Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
                    OptionalAttr<UnitAttr>:$a_transpose,
                    OptionalAttr<UnitAttr>:$b_transpose);

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
Lines changed: 42 additions & 10 deletions

@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                       : NVVM::MMATypes::tf32;
-
+  if (type.getElementType().isF64())
+    return NVVM::MMATypes::f64;
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
   if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
     // then passed on to the intrinsic call. Emit llvm ops to extract individual
     // values form lowered memrefs.
     SmallVector<Value> unpackedOps;
-
     auto unpackOp = [&](Value operand) {
+      // f64 a and b fragments are not structs but scalars.
+      if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+        unpackedOps.push_back(operand);
+        return;
+      }
+      // every other type is lowered to an LLVM struct, extract the values.
       auto structType = cast<LLVM::LLVMStructType>(operand.getType());
       for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
         Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaConstantOp.getLoc();
     Value cst = adaptor.getOperands()[0];
-    LLVM::LLVMStructType type = convertMMAToLLVMType(
+    Type type = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+    // If the element is not a struct, it means it's a scalar f64.
+    auto structType = dyn_cast<LLVM::LLVMStructType>(type);
+    if (!structType) {
+      rewriter.replaceOp(subgroupMmaConstantOp, cst);
+      return success();
+    }
     // If the element type is a vector create a vector from the operand.
-    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+    if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
       Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
       for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
         Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
       }
       cst = vecCst;
     }
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
-    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+    for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
       matrixStruct =
           LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
     }
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaElementwiseOp.getLoc();
     size_t numOperands = adaptor.getOperands().size();
-    LLVM::LLVMStructType destType = convertMMAToLLVMType(
+    Type destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
-    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+    // If the element is not a struct, it means it's a scalar f64.
+    LLVM::LLVMStructType structDestTy =
+        dyn_cast<LLVM::LLVMStructType>(destType);
+    if (!structDestTy) {
+      SmallVector<Value> operands;
+      for (auto operand : adaptor.getOperands()) {
+        operands.push_back(operand);
+      }
+      Value element = createScalarOp(
+          rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
+      rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+      return success();
+    }
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+    for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
       SmallVector<Value> extractedOperands;
       for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
         extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@
 } // namespace
 
 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
   auto nRow = type.getShape()[0];
   auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
       NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+  // Special handling for f64 a and b fragments
+  Type f64Ty = Float64Type::get(type.getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    return f64Ty;
+  }
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
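For reference, a hedged sketch of the lowered forms this file now has to handle: the a-fragment load is taken from the new test below, while the c-fragment shape (two f64 values per lane for the 8x8 accumulator) is assumed from what `NVVM::inferMMAType` reports; `%ptrA`, `%ptrC`, and `%ld` are hypothetical values.

```mlir
// A (and B) fragment loads lower to a bare scalar per lane ...
%a_frag = nvvm.wmma.load %ptrA, %ld {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
// ... while the accumulator (C) fragment still lowers to an LLVM struct.
%c_frag = nvvm.wmma.load %ptrC, %ld {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(f64, f64)>
```

This is why `convertMMAToLLVMType` now returns a bare `f64` when `inferMMAType` reports a single f64 element, and why the unpack, constant, and elementwise patterns check for a non-struct type before extracting or inserting values.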

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Lines changed: 2 additions & 2 deletions

@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32() ||
+  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
 }
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
   if (!MMAMatrixType::isValidElementType(elementType))
     return emitError()
-           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
 
   return success();
 }

mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
Lines changed: 22 additions & 0 deletions

@@ -80,6 +80,28 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
+  // CHECK-SAME: f64
+  // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
+  func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
+    // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
+    // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+    // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
+    // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
+    // CHECK: llvm.return %[[LOAD]] : f64
+  }
+}
+
+// -----
+
 gpu.module @test_module {
 
 // CHECK-LABEL: func @gpu_wmma_store_op

mlir/test/Dialect/GPU/invalid.mlir
Lines changed: 2 additions & 2 deletions

@@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){
 func.func @mmamatrix_invalid_element_type(){
   %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
   %i = arith.constant 16 : index
-  // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
+  // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}}
   %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
   return
 }
@@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){
 // -----
 
 func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
-  // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
+  // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}}
   %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
   return
 }
Lines changed: 72 additions & 0 deletions (new file)

@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @main() {
+  %a = memref.alloc() : memref<8x4xf64>
+  %b = memref.alloc() : memref<4x8xf64>
+  %c = memref.alloc() : memref<8x8xf64>
+  %d = memref.alloc() : memref<8x8xf64>
+
+  %f1 = arith.constant 1.0e+00 : f64
+  %fcst = arith.constant 3.14e+00 : f64
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+
+  // Initialize the Input matrixes with ones.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c4 step %c1 {
+      memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+      memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+    }
+  }
+  // Initialize the accumulator matrix with a constant.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c8 step %c1 {
+      memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+    }
+  }
+
+  %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+  %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+  %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+  %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+  gpu.host_register %2 : memref<*xf64>
+  gpu.host_register %20 : memref<*xf64>
+  gpu.host_register %33 : memref<*xf64>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+    gpu.terminator
+  }
+  // Print the memref after computation.
+  call @printMemrefF64(%34) : (memref<*xf64>) -> ()
+  // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+  // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14]
+  return
+}
+
+func.func private @printMemrefF64(memref<*xf64>)
