@@ -456,7 +456,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
456456}
457457
458458class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
459- string intr = "llvm.nvvm.wmma."
459+ string intr_name = "llvm.nvvm.wmma."
460460 # Frag.geom
461461 # "." # Op
462462 # "." # Frag.frag
@@ -467,7 +467,7 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
467467 // TODO(tra): record name should ideally use the same field order as the intrinsic.
468468 // E.g. string record = !subst("llvm", "int",
469469 // !subst(".", "_", llvm));
470- string record = "int_nvvm_wmma_"
470+ string record_name = "int_nvvm_wmma_"
471471 # Frag.geom
472472 # "_" # Op
473473 # "_" # Frag.frag
@@ -496,7 +496,7 @@ class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
496496class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, string b1op,
497497 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
498498 string signature = MMA_SIGNATURE<A, B, C, D>.ret;
499- string record = "int_nvvm_wmma_"
499+ string record_name = "int_nvvm_wmma_"
500500 # A.geom
501501 # "_mma"
502502 # !subst(".", "_", b1op)
@@ -510,7 +510,7 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
510510class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
511511 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
512512 string signature = MMA_SIGNATURE<A, B, C, D>.ret;
513- string record = "int_nvvm_mma"
513+ string record_name = "int_nvvm_mma"
514514 # !subst(".", "_", b1op)
515515 # "_" # A.geom
516516 # "_" # ALayout
@@ -524,7 +524,7 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
524524 WMMA_REGS A, WMMA_REGS B,
525525 WMMA_REGS C, WMMA_REGS D> {
526526 string signature = MMA_SIGNATURE<A, B, C, D>.ret;
527- string record = "int_nvvm_mma"
527+ string record_name = "int_nvvm_mma"
528528 # "_" # !subst("::", "_", Metadata)
529529 # "_" # A.geom
530530 # "_row_col"
@@ -533,26 +533,37 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
533533 # signature;
534534}
535535
536+ // Helper class that takes an intrinsic name and constructs a record name.
537+ // Additionally, sets `intr_name` to be non-empty if the default name assigned
538+ // to this intrinsic will not match the name given.
//
// Fields:
//   record_name - `name` with "llvm." replaced by "int_" and then every "."
//                 replaced by "_".
//   intr_name   - `name` itself when it contains an "_", otherwise "".
539+ class IntrinsicName<string name> {
540+ string record_name = !subst(".", "_",
541+ !subst("llvm.", "int_", name));
542+ // Use explicit intrinsic name if it has an _ in it, else rely on LLVM
543+ // assigned default name.
// NOTE(review): presumably the default name is derived from the record name
// by mapping "_" back to ".", so a literal "_" in `name` would not
// round-trip -- confirm against the Intrinsic class definition.
544+ string intr_name = !if(!ne(!find(name, "_"), -1), name, "");
545+ }
546+
// Names for one ldmatrix variant, selected by fragment `Frag` and the `Trans`
// flag. The dotted name is the fixed "llvm.nvvm.ldmatrix.sync.aligned" prefix
// followed by the fragment geometry, the fragment kind, ".trans" when `Trans`
// is set, and the PTX element type.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `intr_name`/`record_name` (lines marked `+`), both read from
// IntrinsicName<name>.
536547class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
537- string intr = "llvm.nvvm.ldmatrix.sync.aligned"
548+ defvar name = "llvm.nvvm.ldmatrix.sync.aligned"
538549 # "." # Frag.geom
539550 # "." # Frag.frag
540551 # !if(Trans, ".trans", "")
541552 # "." # Frag.ptx_elt_type
542553 ;
543- string record = !subst(".", "_",
544- !subst("llvm.", "int_", intr)) ;
// Unlike the removed `intr` field, `intr_name` is "" whenever `name` contains
// no "_" (see IntrinsicName).
554+ string intr_name = IntrinsicName<name>.intr_name;
555+ string record_name = IntrinsicName<name>.record_name ;
545556}
546557
// Names for one stmatrix variant, selected by fragment `Frag` and the `Trans`
// flag. The dotted name is the fixed "llvm.nvvm.stmatrix.sync.aligned" prefix
// followed by the fragment geometry, the fragment kind, ".trans" when `Trans`
// is set, and the PTX element type.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `intr_name`/`record_name` (lines marked `+`), both read from
// IntrinsicName<name>.
547558class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
548- string intr = "llvm.nvvm.stmatrix.sync.aligned"
559+ defvar name = "llvm.nvvm.stmatrix.sync.aligned"
549560 # "." # Frag.geom
550561 # "." # Frag.frag
551562 # !if(Trans, ".trans", "")
552563 # "." # Frag.ptx_elt_type
553564 ;
554- string record = !subst(".", "_",
555- !subst("llvm.", "int_", intr)) ;
// Unlike the removed `intr` field, `intr_name` is "" whenever `name` contains
// no "_" (see IntrinsicName).
565+ string intr_name = IntrinsicName<name>.intr_name;
566+ string record_name = IntrinsicName<name>.record_name ;
556567}
557568
558569// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
@@ -1042,45 +1053,49 @@ class NVVM_TCGEN05_MMA_BASE<string Space, bit Sp> {
// Names for one tcgen05.mma variant. The dotted name is
// "llvm.nvvm.tcgen05.mma", then ".sp" when Sp is 1, then "." # Space, then
// ".scale_d" when ScaleInputD is 1, then ".ashift" when AShift is 1.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `name`, `intr_name` and `record_name` (lines marked `+`).
// NOTE(review): `name` is declared as a `string` field here, whereas
// LDMATRIX_NAME/STMATRIX_NAME use `defvar name` -- confirm the extra record
// field is intended.
10421053class NVVM_TCGEN05_MMA<bit Sp, string Space,
10431054 bit AShift, bit ScaleInputD>:
10441055 NVVM_TCGEN05_MMA_BASE<Space, Sp> {
1045- string intr = "llvm.nvvm.tcgen05.mma"
1056+ string name = "llvm.nvvm.tcgen05.mma"
10461057 # !if(!eq(Sp, 1), ".sp", "")
10471058 # "." # Space
10481059 # !if(!eq(ScaleInputD, 1), ".scale_d", "")
10491060 # !if(!eq(AShift, 1), ".ashift", "");
1050- string record = !subst(".", "_", !subst("llvm.", "int_", intr));
// The ".scale_d" suffix contains an "_", so for those variants IntrinsicName
// yields the explicit `name`; otherwise `intr_name` is "" unless another
// component of `name` contains an "_".
1061+ string intr_name = IntrinsicName<name>.intr_name;
1062+ string record_name = IntrinsicName<name>.record_name;
10511063}
10521064
// Names for one block-scaled tcgen05.mma variant. The dotted name is
// "llvm.nvvm.tcgen05.mma", then ".sp" when Sp is 1, then "." # Space,
// "." # Kind, and ".block_scale" # ScaleVecSize.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `name`, `intr_name` and `record_name` (lines marked `+`).
10531065class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
10541066 string Kind, string ScaleVecSize>:
10551067 NVVM_TCGEN05_MMA_BASE<Space, Sp> {
1056- string intr = "llvm.nvvm.tcgen05.mma"
1068+ string name = "llvm.nvvm.tcgen05.mma"
10571069 # !if(!eq(Sp, 1), ".sp", "")
10581070 # "." # Space
10591071 # "." # Kind
10601072 # ".block_scale" # ScaleVecSize;
1061- string record = !subst(".", "_", !subst("llvm.", "int_", intr));
// ".block_scale" is always appended and contains an "_", so IntrinsicName
// always yields the explicit `name` as `intr_name` here.
1073+ string intr_name = IntrinsicName<name>.intr_name;
1074+ string record_name = IntrinsicName<name>.record_name;
10621075}
10631076
// Names for one tcgen05.mma.ws variant. The dotted name is
// "llvm.nvvm.tcgen05.mma.ws", then ".sp" when Sp is 1, then "." # Space, then
// ".zero_col_mask" when ZeroColMask is 1.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `name`, `intr_name` and `record_name` (lines marked `+`).
10641077class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
10651078 NVVM_TCGEN05_MMA_BASE<Space, Sp> {
1066- string intr = "llvm.nvvm.tcgen05.mma.ws"
1079+ string name = "llvm.nvvm.tcgen05.mma.ws"
10671080 # !if(!eq(Sp, 1), ".sp", "")
10681081 # "." # Space
10691082 # !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
1070- string record = !subst(".", "_", !subst("llvm.", "int_", intr));
// The ".zero_col_mask" suffix contains an "_", so for those variants
// IntrinsicName yields the explicit `name`; otherwise `intr_name` is ""
// unless another component of `name` contains an "_".
1083+ string intr_name = IntrinsicName<name>.intr_name;
1084+ string record_name = IntrinsicName<name>.record_name;
10711085}
10721086
// Names for one tcgen05.mma disable-output-lane variant. The dotted name is
// "llvm.nvvm.tcgen05.mma", then ".sp" when Sp is 1, then "." # Space, then
// ".scale_d" when ScaleInputD is 1, then ".disable_output_lane.cg" # CtaGroup,
// then ".ashift" when AShift is 1.
// In this diff the `intr`/`record` fields (lines marked `-`) are replaced by
// `name`, `intr_name` and `record_name` (lines marked `+`).
10731087class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
10741088 int CtaGroup, bit AShift,
10751089 bit ScaleInputD>:
10761090 NVVM_TCGEN05_MMA_BASE<Space, Sp> {
1077- string intr = "llvm.nvvm.tcgen05.mma"
1091+ string name = "llvm.nvvm.tcgen05.mma"
10781092 # !if(!eq(Sp, 1), ".sp", "")
10791093 # "." # Space
10801094 # !if(!eq(ScaleInputD, 1), ".scale_d", "")
10811095 # ".disable_output_lane.cg" # CtaGroup
10821096 # !if(!eq(AShift, 1), ".ashift", "");
1083- string record = !subst(".", "_", !subst("llvm.", "int_", intr));
// ".disable_output_lane.cg" is always appended and contains an "_", so
// IntrinsicName always yields the explicit `name` as `intr_name` here.
1097+ string intr_name = IntrinsicName<name>.intr_name;
1098+ string record_name = IntrinsicName<name>.record_name;
10841099}
10851100
10861101class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
@@ -2273,7 +2288,7 @@ class NVVM_WMMA_LD<WMMA_REGS Frag, string Layout, int WithStride>
22732288 : Intrinsic<Frag.regs,
22742289 !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
22752290 [IntrWillReturn, IntrReadMem, IntrArgMemOnly, IntrNoCallback, ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2276- WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr >;
2291+ WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr_name >;
22772292
22782293// WMMA.STORE.D
22792294class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
@@ -2283,18 +2298,18 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
22832298 Frag.regs,
22842299 !if(WithStride, [llvm_i32_ty], [])),
22852300 [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2286- WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr >;
2301+ WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr_name >;
22872302
22882303// Create all load/store variants
// For every layout ("row"/"col") and stride flag (0/1), emit one def per
// supported load fragment (deriving from NVVM_WMMA_LD) and one per supported
// store fragment (deriving from NVVM_WMMA_ST). Each def is named by the
// matching WMMA_NAME_LDST record-name field; this diff renames that field from
// `record` (lines marked `-`) to `record_name` (lines marked `+`).
22892304foreach layout = ["row", "col"] in {
22902305 foreach stride = [0, 1] in {
22912306 foreach frag = NVVM_MMA_OPS.all_ld_ops in
22922307 if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
2293- def WMMA_NAME_LDST<"load", frag, layout, stride>.record
2308+ def WMMA_NAME_LDST<"load", frag, layout, stride>.record_name
22942309 : NVVM_WMMA_LD<frag, layout, stride>;
22952310 foreach frag = NVVM_MMA_OPS.all_st_ops in
22962311 if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
2297- def WMMA_NAME_LDST<"store", frag, layout, stride>.record
2312+ def WMMA_NAME_LDST<"store", frag, layout, stride>.record_name
22982313 : NVVM_WMMA_ST<frag, layout, stride>;
22992314 }
23002315}
@@ -2313,7 +2328,7 @@ foreach layout_a = ["row", "col"] in {
23132328 foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
23142329 if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
23152330 def WMMA_NAME<layout_a, layout_b, satf, rnd, b1op,
2316- op[0], op[1], op[2], op[3]>.record
2331+ op[0], op[1], op[2], op[3]>.record_name
23172332 : NVVM_MMA<op[0], op[1], op[2], op[3]>;
23182333 }
23192334 } // b1op
@@ -2330,7 +2345,7 @@ foreach layout_a = ["row", "col"] in {
23302345 foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
23312346 foreach kind = ["", "kind::f8f6f4"] in {
23322347 if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
2333- def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
2348+ def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record_name
23342349 : NVVM_MMA<op[0], op[1], op[2], op[3]>;
23352350 }
23362351 } // kind
@@ -2379,7 +2394,7 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
23792394 foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
23802395 if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
23812396 def MMA_SP_NAME<metadata, kind, satf,
2382- op[0], op[1], op[2], op[3]>.record
2397+ op[0], op[1], op[2], op[3]>.record_name
23832398 : NVVM_MMA_SP<op[0], op[1], op[2], op[3]>;
23842399 }
23852400 } // op
@@ -2392,12 +2407,12 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
23922407 : Intrinsic<Frag.regs, [llvm_anyptr_ty],
23932408 [IntrReadMem, IntrArgMemOnly, IntrNoCallback, ReadOnly<ArgIndex<0>>,
23942409 NoCapture<ArgIndex<0>>],
2395- LDMATRIX_NAME<Frag, Transposed>.intr >;
2410+ LDMATRIX_NAME<Frag, Transposed>.intr_name >;
23962411
23972412foreach transposed = [0, 1] in {
23982413 foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
23992414 if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
2400- def LDMATRIX_NAME<frag, transposed>.record
2415+ def LDMATRIX_NAME<frag, transposed>.record_name
24012416 : NVVM_LDMATRIX<frag, transposed>;
24022417 }
24032418 }
@@ -2409,12 +2424,12 @@ class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
24092424 !listconcat([llvm_anyptr_ty], Frag.regs),
24102425 [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
24112426 WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2412- STMATRIX_NAME<Frag, Transposed>.intr >;
2427+ STMATRIX_NAME<Frag, Transposed>.intr_name >;
24132428
24142429foreach transposed = [0, 1] in {
24152430 foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
24162431 if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
2417- def STMATRIX_NAME<frag, transposed>.record
2432+ def STMATRIX_NAME<frag, transposed>.record_name
24182433 : NVVM_STMATRIX<frag, transposed>;
24192434 }
24202435 }
@@ -2767,14 +2782,15 @@ foreach cta_group = ["cg1", "cg2"] in {
27672782 "64x128b_warpx2_02_13",
27682783 "64x128b_warpx2_01_23",
27692784 "32x128b_warpx4"] in {
2770- defvar intr_suffix = StrJoin<"_", [shape, src_fmt, cta_group]>.ret;
2771- defvar name_suffix = StrJoin<".", [shape, src_fmt, cta_group]>.ret;
2785+ defvar name = "llvm.nvvm.tcgen05.cp." #
2786+ StrJoin<".", [shape, src_fmt, cta_group]>.ret;
27722787
2773- def int_nvvm_tcgen05_cp_ # intr_suffix : Intrinsic<[],
2788+ defvar intrinsic_name = IntrinsicName<name>;
2789+ def intrinsic_name.record_name : Intrinsic<[],
27742790 [llvm_tmem_ptr_ty, // tmem_addr
27752791 llvm_i64_ty], // smem descriptor
27762792 [IntrConvergent, IntrInaccessibleMemOrArgMemOnly, NoCapture<ArgIndex<0>>],
2777- "llvm.nvvm.tcgen05.cp." # name_suffix >;
2793+ intrinsic_name.intr_name >;
27782794 }
27792795 }
27802796}
@@ -2881,9 +2897,9 @@ foreach sp = [0, 1] in {
28812897 ]
28822898 );
28832899
2884- def mma.record :
2900+ def mma.record_name :
28852901 DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2886- mma.intr >;
2902+ mma.intr_name >;
28872903 }
28882904 }
28892905 }
@@ -2918,8 +2934,8 @@ foreach sp = [0, 1] in {
29182934 Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
29192935 );
29202936
2921- def mma.record : DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties ,
2922- mma.intr >;
2937+ def mma.record_name : DefaultAttrsIntrinsicFlags<[], args, flags,
2938+ intrinsic_properties, mma.intr_name >;
29232939 } // ashift
29242940 } // scale_d
29252941 } // cta_group
@@ -2944,11 +2960,11 @@ foreach sp = [0, 1] in {
29442960 defvar collector_usage = ArgIndex<!add(nargs, 1)>;
29452961
29462962 if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
2947- def mma.record : DefaultAttrsIntrinsicFlags<[], args, flags,
2963+ def mma.record_name : DefaultAttrsIntrinsicFlags<[], args, flags,
29482964 !listconcat(mma.common_intr_props,
29492965 [Range<cta_group, 1, 3>,
29502966 Range<collector_usage, 0, 4>]),
2951- mma.intr >;
2967+ mma.intr_name >;
29522968 }
29532969 }
29542970 }
@@ -2977,9 +2993,9 @@ foreach sp = [0, 1] in {
29772993 Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
29782994 );
29792995
2980- def mma.record :
2996+ def mma.record_name :
29812997 DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2982- mma.intr >;
2998+ mma.intr_name >;
29832999 }
29843000 }
29853001}
0 commit comments