@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
272
272
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
273
273
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
274
274
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
275
+ !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
276
+ !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
277
+ !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278
+ !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
275
279
276
280
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
277
281
// All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
298
302
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
299
303
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
300
304
305
+ !eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
306
+ !eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
307
+ !eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
308
+ !eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
309
+
310
+ !eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
311
+ !eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
312
+ !eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
313
+ !eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
314
+
315
+ !eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
316
+ !eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
317
+ !eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
318
+ !eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
319
+
301
320
// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
302
321
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
303
322
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
378
397
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
379
398
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
380
399
400
+ // mma e4m3/e5m2 -> f16/f32 @ m16n8k16
401
+ !eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
402
+ !eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
403
+ !eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
404
+ !eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
405
+ // mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
406
+ !eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
407
+ !eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
408
+ !eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
409
+ !eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
410
+ !eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
411
+ !eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
412
+ !eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
413
+ !eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
414
+ !eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
415
+ !eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
416
+ // mma e2m1 -> f32 @m16n8k64
417
+ !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
418
+ !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
419
+
381
420
// wmma/mma b1 -> s32 @ m8n8k128(b1)
382
421
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
383
422
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,14 +507,15 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
468
507
# !if(Satfinite, "_satfinite", "");
469
508
}
470
509
471
- class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
510
+ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
472
511
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
473
512
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
474
513
string record = "int_nvvm_mma"
475
514
# !subst(".", "_", b1op)
476
515
# "_" # A.geom
477
516
# "_" # ALayout
478
517
# "_" # BLayout
518
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
479
519
# !if(Satfinite, "_satfinite", "")
480
520
# signature;
481
521
}
@@ -601,14 +641,26 @@ class NVVM_MMA_OPS {
601
641
["m16n8k16", "m16n8k8"],
602
642
["bf16"], [], ["f32"], []>.ret;
603
643
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
604
- ["m8n8k4"],
644
+ ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16" ],
605
645
["f64"], [], ["f64"], []>.ret;
606
646
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
607
647
["m8n8k4", "m16n8k8", "m16n8k16"],
608
648
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
609
649
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
610
650
["m8n8k16", "m16n8k16", "m16n8k32"],
611
651
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
652
+ // m16n8k32 fp8 variants are intersected with f8f6f4 variants
653
+ // and processed there
654
+ list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
655
+ ["m16n8k16"],
656
+ ["e4m3", "e5m2"], ["e4m3", "e5m2"],
657
+ ["f16", "f32"], ["f16", "f32"]>.ret;
658
+ // it also contains e4m3/e5m2 from fp8 variants
659
+ list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
660
+ ["m16n8k32"],
661
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
662
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
663
+ ["f16", "f32"], ["f16", "f32"]>.ret;
612
664
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
613
665
["m8n8k32", "m16n8k32", "m16n8k64"],
614
666
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
617
669
["b1"], [], ["s32"], []>.ret;
618
670
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
619
671
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
620
- fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
672
+ fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
673
+ int_mma_ops, subint_mma_ops, bit_mma_ops);
621
674
622
675
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
623
676
["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
770
823
// if NVVM_MMA_SUPPORTED<...>.ret then
771
824
// def : FOO<>; // The record will only be defined for supported ops.
772
825
//
773
- class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
826
+ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
827
+ string kind, int satf> {
774
828
// MMA ops check both layouts.
775
829
string layout = layout_a # ":" # layout_b;
776
830
string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
805
859
!or(!ne(a_type, b_type),
806
860
!ne(c_type, d_type))): false,
807
861
808
- // m16n8k8 requires C and D to be the same type.
809
- !and(!eq(geom, "m16n8k8"),
862
+ // m16n8k16/m16n8k32 requires C and D to be the same type
863
+ !and(!or(!eq(geom, "m16n8k16"),
864
+ !eq(geom, "m16n8k32")),
810
865
!ne(c_type, d_type)): false,
811
866
867
+ // Limit kind to valid types and geometries
868
+ !and(!ne(kind, ""),
869
+ !or(!ne(geom, "m16n8k32"),
870
+ !and(!ne(a_type, "e4m3"),
871
+ !ne(a_type, "e5m2"),
872
+ !ne(a_type, "e3m2"),
873
+ !ne(a_type, "e2m3"),
874
+ !ne(a_type, "e2m1")))): false,
875
+
876
+ // Limit m16n8k16/m16n8k32 with no kind to valid types
877
+ !and(!eq(kind, ""),
878
+ !or(!eq(geom, "m16n8k16"),
879
+ !eq(geom, "m16n8k32")),
880
+ !or(!eq(a_type, "e3m2"),
881
+ !eq(a_type, "e2m3"),
882
+ !eq(a_type, "e2m1"),
883
+ !eq(b_type, "e3m2"),
884
+ !eq(b_type, "e2m3"),
885
+ !eq(b_type, "e2m1"))): false,
886
+
812
887
// All other are OK.
813
888
true: true
814
889
);
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
882
957
!eq(a_type, "tf32")),
883
958
!ne(a_type, b_type)): false,
884
959
885
- // m16n8k16 and m16n8k32 requires C and D to be the same type.
960
+ // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
886
961
!and(!or(!eq(geom, "m16n8k16"),
887
- !eq(geom, "m16n8k32")),
962
+ !eq(geom, "m16n8k32"),
963
+ !eq(geom, "m16n8k64")),
888
964
!ne(c_type, d_type)): false,
889
965
890
966
!and(!eq(kind, ""),
@@ -2252,10 +2328,12 @@ foreach layout_a = ["row", "col"] in {
2252
2328
foreach satf = [0, 1] in {
2253
2329
foreach op = NVVM_MMA_OPS.all_mma_ops in {
2254
2330
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
2255
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
2256
- def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
2257
- : NVVM_MMA<op[0], op[1], op[2], op[3]>;
2258
- }
2331
+ foreach kind = ["", "kind::f8f6f4"] in {
2332
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
2333
+ def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
2334
+ : NVVM_MMA<op[0], op[1], op[2], op[3]>;
2335
+ }
2336
+ } // kind
2259
2337
} // b1op
2260
2338
} // op
2261
2339
} // satf
0 commit comments