1010from itertools import product
1111from string import Template
1212
13+
1314class MMAType :
1415 def __init__ (self , ptx_type ):
1516 self .ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
176177 "m8n16:x1:b8x16.b4x16_p64" : 1 ,
177178 "m8n16:x2:b8x16.b4x16_p64" : 2 ,
178179 "m8n16:x4:b8x16.b4x16_p64" : 4 ,
180+ # stmatrix
181+ "m8n8:x1:b16" : 1 ,
182+ "m8n8:x2:b16" : 2 ,
183+ "m8n8:x4:b16" : 4 ,
184+ "m16n8:x1:b8" : 1 ,
185+ "m16n8:x2:b8" : 2 ,
186+ "m16n8:x4:b8" : 4 ,
179187 }.get (
180188 "%s:%s:%s" % (geom , frag , ptx_elt_type ),
181189 {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
241249 ]
242250
243251
252+ def make_stmatrix_ops (geoms , frags , types ):
253+ return [
254+ MMAFrag (geom , frag , ptx_type )
255+ for (geom , frag , ptx_type ) in product (geoms , frags , types )
256+ ]
257+
258+
244259def get_wmma_ops ():
245260 return (
246261 make_mma_ops (["m16n16k8" ], ["tf32" ], [], ["f32" ], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
315330 )
316331
317332
333+ def get_stmatrix_ops ():
334+ return make_stmatrix_ops (["m8n8" ], ["x1" , "x2" , "x4" ], ["b16" ]) + make_stmatrix_ops (
335+ ["m16n8" ], ["x1" , "x2" , "x4" ], ["b8" ]
336+ )
337+
338+
318339def is_wmma_geom_supported (geom ):
319340 # geometries for FP and ints.
320341 if geom in ["m8n32k16" , "m32n8k16" ]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
360381 assert False # Unexpected geometry.
361382
362383
384+ def is_stmatrix_geom_supported (geom ):
385+ if geom in ["m8n8" ]:
386+ return ptx_version >= 78 and gpu_arch >= 90
387+ elif geom in ["m16n8" ]:
388+ return ptx_version >= 86 and gpu_arch >= 100 and aa
389+ assert False # Unexpected geometry.
390+
391+
363392def is_ldmatrix_trans_supported (geom , trans ):
364393 if geom in ["m8n8" ]:
365394 return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
369398 return trans == ""
370399 assert False # Unexpected geometry.
371400
401+
402+ def is_stmatrix_trans_supported (geom , trans ):
403+ if geom in ["m8n8" ]:
404+ return True
405+ elif geom in ["m16n8" ]:
406+ return trans == ".trans"
407+ assert False # Unexpected geometry.
408+
409+
372410def is_type_supported (ptx_type ):
373411 if ptx_type in ["s8" , "u8" , "s32" ]:
374412 return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
463501 return frag .frag in ["x1" , "x2" , "x4" ]
464502
465503
504+ def is_stmatrix_variant_supported (frag , trans ):
505+ if not (
506+ is_type_supported (frag .mma_type .ptx_type )
507+ and is_stmatrix_geom_supported (frag .geom )
508+ and is_stmatrix_trans_supported (frag .geom , trans )
509+ ):
510+ return False
511+ return frag .frag in ["x1" , "x2" , "x4" ]
512+
513+
466514def make_wmma_slice_ty (frag ):
467515 return [frag .mma_type .llvm_type ] * frag .nregs
468516
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
716764
717765 return generated_items
718766
767+ def gen_stmatrix_tests ():
768+ stmatrix_template = """
769+ declare void @${intrinsic}(i8 ${as}* %dst, ${args});
770+
771+ ; CHECK-LABEL: .func {{.*}}test_${function}(
772+ define void @test_${function}(i8 ${as}* %dst, ${args}) {
773+ ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
774+ ; CHECK: {${check_args}}
775+ call void @${intrinsic}(i8${as}* %dst, ${args});
776+ ret void
777+ }
778+
779+ ; CHECK-LABEL: .func{{.*}}test_${function}_o(
780+ define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
781+ ; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
782+ ; CHECK: {${check_args}}
783+ %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
784+ call void @${intrinsic}(i8 ${as}* %dst1, ${args});
785+ ret void
786+ }
787+ """
788+ intrinsic_template = (
789+ "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
790+ )
791+ instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
792+ )
793+ generated_items = []
794+
795+ for frag , space , trans in product (get_stmatrix_ops (),
796+ ["" , ".shared" ],
797+ ["" , ".trans" ],
798+ ):
799+ if not is_stmatrix_variant_supported (frag , trans ):
800+ continue
801+
802+ params = {
803+ "frag" : frag .frag ,
804+ "space" : space ,"trans" : trans ,
805+ "itype" : frag .mma_type .ptx_type ,
806+ "pspace" : get_pspace (space ),
807+ "as" : "addrspace(%d)" % get_aspace (space ),
808+ "geom" : frag .geom ,
809+ }
810+
811+ test_params = params
812+ test_params ["intrinsic" ] = Template (intrinsic_template ).substitute (params )
813+ test_params ["function" ] = test_params ["intrinsic" ].replace ("." , "_" )
814+ test_params ["instruction" ] = Template (instruction_template ).substitute (params )
815+ test_params ["args" ] = make_wmma_slice_args (frag )
816+ test_params ["check_args" ] = check_pattern (frag )
817+
818+ print (Template (stmatrix_template ).substitute (test_params ))
819+ generated_items .append ((test_params ["intrinsic" ], test_params ["instruction" ]))
820+
821+ return generated_items
719822
720823def mma_signature (op ):
721824 if op .a .mma_type .ptx_type == "f16" :
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
893996; NOALTFLOAT-NOT: .{{bf16|tf32}}
894997; NODOUBLE-NOT: .f64
895998; NOLDMATRIX-NOT: ldmatrix.sync.aligned
999+ ; NOSTMATRIX-NOT: stmatrix.sync.aligned
8961000
8971001; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
8981002; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
9941098; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
9951099; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
9961100
1101+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
1102+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
1103+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
1104+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
1105+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
1106+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
1107+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
1108+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
1109+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
1110+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
1111+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
1112+ ; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
1113+
1114+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
1115+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
1116+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
1117+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
1118+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
1119+ ; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
1120+
9971121; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
9981122; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
9991123; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
10391163 items = gen_wmma_load_tests ()
10401164 items += gen_wmma_store_tests ()
10411165 items += gen_ldmatrix_tests ()
1166+ items += gen_stmatrix_tests ()
10421167 items += gen_wmma_mma_tests ()
10431168 items += gen_mma_tests ()
10441169 gen_check_unsupported_ops (items )
0 commit comments