Skip to content

Commit 0ed8e66

Browse files
authored
[MLIR][NVVM] Extend NVVM mma ops to support fp64 (#165380)
This PR extends the NVVM `wmma` ops (`nvvm.wmma.load`, `nvvm.wmma.mma`, `nvvm.wmma.store`) to support the fp64 type. The extension requires special handling of the result type of load ops for fragments `a` and `b`, since for fp64 they return a scalar `f64` instead of a struct.
1 parent 511c9c0 commit 0ed8e66

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,9 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
19991999
// llvm supports and can be extended as needed.
20002000
class NVVM_MMA_OPS {
20012001
// "wmma" operations
2002+
list<list<WMMA_REGS>> fp64_wmma_ops = MMA_OPS<
2003+
[GEOM<8, 8, 4>],
2004+
["f64"], [], ["f64"], []>.ret;
20022005
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
20032006
[GEOM<16, 16, 8>],
20042007
["tf32"], [], ["f32"], []>.ret;
@@ -2009,6 +2012,7 @@ class NVVM_MMA_OPS {
20092012
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
20102013
["s8","u8"], [], ["s32"], []>.ret;
20112014
list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
2015+
fp64_wmma_ops,
20122016
tf32_wmma_ops,
20132017
fp_wmma_ops,
20142018
i8_wmma_ops);
@@ -2025,9 +2029,17 @@ class NVVM_MMA_OPS {
20252029
list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
20262030
[GEOM<16, 16, 8>],
20272031
["c", "d"], ["f32"]>.ret;
2032+
list<WMMA_REGS> ldst_f64_ab_ops = MMA_LDST_OPS<
2033+
[GEOM<8, 8, 4>],
2034+
["a", "b"], ["f64"]>.ret;
2035+
list<WMMA_REGS> ldst_f64_cd_ops = MMA_LDST_OPS<
2036+
[GEOM<8, 8, 4>],
2037+
["c", "d"], ["f64"]>.ret;
20282038
list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
20292039
ldst_tf32_ab_ops,
2030-
ldst_tf32_cd_ops);
2040+
ldst_tf32_cd_ops,
2041+
ldst_f64_ab_ops,
2042+
ldst_f64_cd_ops);
20312043
// Separate A/B/C fragments (loads) from D (stores).
20322044
list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
20332045
list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
@@ -2334,7 +2346,7 @@ def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> {
23342346
}
23352347

23362348
def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
2337-
Results<(outs LLVM_AnyStruct:$res)>,
2349+
Results<(outs AnyTypeOf<[LLVM_AnyStruct, F64]>:$res)>,
23382350
Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
23392351
I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
23402352
MMATypesAttr:$eltype, MMAFragAttr:$frag)> {

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
896896
} else if (type == NVVM::MMATypes::f32) {
897897
elementType = builder.getF32Type();
898898
numberElements = 8;
899+
} else if (type == NVVM::MMATypes::f64) {
900+
elementType = builder.getF64Type();
901+
if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
902+
numberElements = 1;
903+
else
904+
numberElements = 2;
899905
} else if (type == NVVM::MMATypes::tf32) {
900906
elementType = builder.getI32Type();
901907
numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
954960
return emitOpError() << "invalid attribute combination";
955961
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
956962
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
963+
// Special case: a single-element f64 fragment is returned as a scalar f64, not a struct.
964+
Type f64Ty = Float64Type::get(getContext());
965+
if (typeInfo.first == f64Ty && typeInfo.second == 1) {
966+
if (getType() != f64Ty)
967+
return emitOpError("expected destination type to be f64");
968+
return success();
969+
}
970+
// Everything else is a struct
957971
Type dstType = LLVM::LLVMStructType::getLiteral(
958972
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
959973
if (getType() != dstType)

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,14 @@ func.func @invalid_range_equal_bounds() {
621621
%0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
622622
return
623623
}
624+
625+
// -----
626+
627+
// Test that wmma.load rejects a struct result type for an f64 fragment `a`; a scalar f64 is expected.
628+
llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
629+
// expected-error @below {{'nvvm.wmma.load' op expected destination type to be f64}}
630+
%0 = nvvm.wmma.load %arg0, %arg1
631+
{eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
632+
: (!llvm.ptr) -> !llvm.struct<(f64)>
633+
llvm.return
634+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,43 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
463463
llvm.return
464464
}
465465

466+
// CHECK-LABEL: @nvvm_wmma_load_a_f64
467+
llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
468+
// CHECK: call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
469+
%0 = nvvm.wmma.load %arg0, %arg1
470+
{eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
471+
: (!llvm.ptr) -> f64
472+
llvm.return
473+
}
474+
475+
// CHECK-LABEL: @nvvm_wmma_load_c_f64
476+
llvm.func @nvvm_wmma_load_c_f64(%arg0: !llvm.ptr, %arg1 : i32) {
477+
// CHECK: call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
478+
%0 = nvvm.wmma.load %arg0, %arg1
479+
{eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
480+
: (!llvm.ptr) -> !llvm.struct<(f64, f64)>
481+
llvm.return
482+
}
483+
484+
// CHECK-LABEL: @nvvm_wmma_mma_f64
485+
llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
486+
// CHECK: { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}})
487+
%r = nvvm.wmma.mma %0, %1, %2, %3
488+
{eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>, k = 4 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
489+
: (f64, f64, f64, f64)
490+
-> !llvm.struct<(f64, f64)>
491+
llvm.return
492+
}
493+
494+
// CHECK-LABEL: @nvvm_wmma_store_d_f64
495+
llvm.func @nvvm_wmma_store_d_f64(%arg0: !llvm.ptr, %arg1 : i32, %arg2 : f64, %arg3 : f64) {
496+
// CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, double %{{.*}}, double %{{.*}}, i32 %{{.*}})
497+
nvvm.wmma.store %arg0, %arg1, %arg2, %arg3
498+
{eltype = #nvvm.mma_type<f64>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
499+
: !llvm.ptr, f64, f64
500+
llvm.return
501+
}
502+
466503
// CHECK-LABEL: @cp_async
467504
llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
468505
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})

0 commit comments

Comments
 (0)