Skip to content

Commit 3bae5c9

Browse files
kvederniaokblast
authored andcommitted
[NVPTX] Added more MMA intrinsics for F8F6F4 and FP64 types. (#156040)
This change adds more MMA intrinsics for F8F6F4 and FP64 types. The implementation is based on [PTX ISA version 9.0](https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma). New restrictions were added for dtype/ctype combinations for MMA and sparse MMA intrinsics. MLIR restrictions for dtype/ctype MMA intrinsics were aligned with NVVM IR.
1 parent 8a40f0e commit 3bae5c9

File tree

9 files changed

+252
-85
lines changed

9 files changed

+252
-85
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
272272
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
273273
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
274274
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
275+
!eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
276+
!eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
277+
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278+
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
275279

276280
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
277281
// All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
298302
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
299303
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
300304

305+
!eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
306+
!eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
307+
!eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
308+
!eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
309+
310+
!eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
311+
!eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
312+
!eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
313+
!eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
314+
315+
!eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
316+
!eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
317+
!eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
318+
!eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
319+
301320
// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
302321
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
303322
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
378397
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
379398
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
380399

400+
// mma e4m3/e5m2 -> f16/f32 @ m16n8k16
401+
!eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
402+
!eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
403+
!eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
404+
!eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
405+
// mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
406+
!eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
407+
!eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
408+
!eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
409+
!eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
410+
!eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
411+
!eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
412+
!eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
413+
!eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
414+
!eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
415+
!eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
416+
// mma e2m1 -> f32 @m16n8k64
417+
!eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
418+
!eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
419+
381420
// wmma/mma b1 -> s32 @ m8n8k128(b1)
382421
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
383422
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,14 +507,15 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
468507
# !if(Satfinite, "_satfinite", "");
469508
}
470509

471-
class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
510+
class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
472511
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
473512
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
474513
string record = "int_nvvm_mma"
475514
# !subst(".", "_", b1op)
476515
# "_" # A.geom
477516
# "_" # ALayout
478517
# "_" # BLayout
518+
# !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
479519
# !if(Satfinite, "_satfinite", "")
480520
# signature;
481521
}
@@ -601,14 +641,26 @@ class NVVM_MMA_OPS {
601641
["m16n8k16", "m16n8k8"],
602642
["bf16"], [], ["f32"], []>.ret;
603643
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
604-
["m8n8k4"],
644+
["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
605645
["f64"], [], ["f64"], []>.ret;
606646
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
607647
["m8n8k4", "m16n8k8", "m16n8k16"],
608648
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
609649
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
610650
["m8n8k16", "m16n8k16", "m16n8k32"],
611651
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
652+
// m16n8k32 fp8 variants are intersected with f8f6f4 variants
653+
// and processed there
654+
list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
655+
["m16n8k16"],
656+
["e4m3", "e5m2"], ["e4m3", "e5m2"],
657+
["f16", "f32"], ["f16", "f32"]>.ret;
658+
// it also contains e4m3/e5m2 from fp8 variants
659+
list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
660+
["m16n8k32"],
661+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
662+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
663+
["f16", "f32"], ["f16", "f32"]>.ret;
612664
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
613665
["m8n8k32", "m16n8k32", "m16n8k64"],
614666
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
617669
["b1"], [], ["s32"], []>.ret;
618670
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
619671
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
620-
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
672+
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
673+
int_mma_ops, subint_mma_ops, bit_mma_ops);
621674

622675
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
623676
["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
770823
// if NVVM_MMA_SUPPORTED<...>.ret then
771824
// def : FOO<>; // The record will only be defined for supported ops.
772825
//
773-
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
826+
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
827+
string kind, int satf> {
774828
// MMA ops check both layouts.
775829
string layout = layout_a # ":" # layout_b;
776830
string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
805859
!or(!ne(a_type, b_type),
806860
!ne(c_type, d_type))): false,
807861

808-
// m16n8k8 requires C and D to be the same type.
809-
!and(!eq(geom, "m16n8k8"),
862+
// m16n8k16/m16n8k32 requires C and D to be the same type
863+
!and(!or(!eq(geom, "m16n8k16"),
864+
!eq(geom, "m16n8k32")),
810865
!ne(c_type, d_type)): false,
811866

867+
// Limit kind to valid types and geometries
868+
!and(!ne(kind, ""),
869+
!or(!ne(geom, "m16n8k32"),
870+
!and(!ne(a_type, "e4m3"),
871+
!ne(a_type, "e5m2"),
872+
!ne(a_type, "e3m2"),
873+
!ne(a_type, "e2m3"),
874+
!ne(a_type, "e2m1")))): false,
875+
876+
// Limit m16n8k16/m16n8k32 with no kind to valid types
877+
!and(!eq(kind, ""),
878+
!or(!eq(geom, "m16n8k16"),
879+
!eq(geom, "m16n8k32")),
880+
!or(!eq(a_type, "e3m2"),
881+
!eq(a_type, "e2m3"),
882+
!eq(a_type, "e2m1"),
883+
!eq(b_type, "e3m2"),
884+
!eq(b_type, "e2m3"),
885+
!eq(b_type, "e2m1"))): false,
886+
812887
// All other are OK.
813888
true: true
814889
);
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
882957
!eq(a_type, "tf32")),
883958
!ne(a_type, b_type)): false,
884959

885-
// m16n8k16 and m16n8k32 requires C and D to be the same type.
960+
// m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
886961
!and(!or(!eq(geom, "m16n8k16"),
887-
!eq(geom, "m16n8k32")),
962+
!eq(geom, "m16n8k32"),
963+
!eq(geom, "m16n8k64")),
888964
!ne(c_type, d_type)): false,
889965

890966
!and(!eq(kind, ""),
@@ -2252,10 +2328,12 @@ foreach layout_a = ["row", "col"] in {
22522328
foreach satf = [0, 1] in {
22532329
foreach op = NVVM_MMA_OPS.all_mma_ops in {
22542330
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
2255-
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
2256-
def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
2257-
: NVVM_MMA<op[0], op[1], op[2], op[3]>;
2258-
}
2331+
foreach kind = ["", "kind::f8f6f4"] in {
2332+
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
2333+
def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
2334+
: NVVM_MMA<op[0], op[1], op[2], op[3]>;
2335+
}
2336+
} // kind
22592337
} // b1op
22602338
} // op
22612339
} // satf

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4528,6 +4528,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
45284528
!eq(ptx_elt_type, "e2m1"),
45294529
!ne(kind, "")) : [hasSM120a, hasPTX<87>],
45304530

4531+
!and(!or(!eq(ptx_elt_type,"e4m3"),
4532+
!eq(ptx_elt_type,"e5m2")),
4533+
!eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
4534+
45314535
!or(!eq(ptx_elt_type, "e4m3"),
45324536
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
45334537

@@ -4543,6 +4547,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
45434547
!and(!eq(geom, "m8n8k4"),
45444548
!eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
45454549

4550+
!and(!or(!eq(geom, "m16n8k4"),
4551+
!eq(geom, "m16n8k8"),
4552+
!eq(geom, "m16n8k16")),
4553+
!eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
4554+
45464555
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
45474556
!and(!or(!eq(geom, "m8n32k16"),
45484557
!eq(geom, "m32n8k16")),
@@ -4827,8 +4836,8 @@ defset list<WMMA_INSTR> WMMAs = {
48274836
// MMA
48284837
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
48294838
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
4830-
string ALayout, string BLayout, int Satfinite, string b1op>
4831-
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
4839+
string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
4840+
: WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
48324841
[FragA.Ins, FragB.Ins, FragC.Ins]>,
48334842
// Requires does not seem to have effect on Instruction w/o Patterns.
48344843
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4843,6 +4852,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
48434852
# FragA.geom
48444853
# "." # ALayout
48454854
# "." # BLayout
4855+
# !if(!ne(Kind, ""), "." # Kind, "")
48464856
# !if(Satfinite, ".satfinite", "")
48474857
# TypeList
48484858
# b1op # "\n\t\t"
@@ -4859,13 +4869,15 @@ defset list<WMMA_INSTR> MMAs = {
48594869
foreach satf = [0, 1] in {
48604870
foreach op = NVVM_MMA_OPS.all_mma_ops in {
48614871
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
4862-
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
4863-
def : MMA<WMMA_REGINFO<op[0], "mma">,
4864-
WMMA_REGINFO<op[1], "mma">,
4865-
WMMA_REGINFO<op[2], "mma">,
4866-
WMMA_REGINFO<op[3], "mma">,
4867-
layout_a, layout_b, satf, b1op>;
4868-
}
4872+
foreach kind = ["", "kind::f8f6f4"] in {
4873+
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
4874+
def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
4875+
WMMA_REGINFO<op[1], "mma", "", kind>,
4876+
WMMA_REGINFO<op[2], "mma", "", kind>,
4877+
WMMA_REGINFO<op[3], "mma", "", kind>,
4878+
layout_a, layout_b, satf, b1op, kind>;
4879+
}
4880+
} // kind
48694881
} // b1op
48704882
} // op
48714883
} // satf

llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
33
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
44
# RUN: | FileCheck %t-ptx87-sm_120a.ll
5-
# RUN: %if ptxas-12.7 %{ \
5+
# RUN: %if ptxas-sm_120a && ptxas-isa-8.7 %{ \
66
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
77
# RUN: | %ptxas-verify -arch=sm_120a \
88
# RUN: %}

0 commit comments

Comments
 (0)