@@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> {
599599 string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");
600600}
601601
602- // From Global to Shared memory (G2S)
603- class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
604- string prefix = "cp.async.bulk.tensor";
605- string dir = "shared::cluster.global";
606- string completion = "mbarrier::complete_tx::bytes";
607- string inst_name = prefix
608- # "." # dim # "d"
609- # "." # dir
610- # "." # mode
611- # "." # completion
612- # !if(mc, ".multicast::cluster", "")
613- # !if(ch, ".L2::cache_hint", "");
614- string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_"
615- # dim # "D"
616- # !if(is_shared32, "_SHARED32", "")
617- # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
618- }
619-
// i32 operand type whose assembly text is produced by the instruction
// printer method named "printCTAGroup" (set through PrintMethod).
620602def CTAGroupFlags : Operand<i32> {
621603 let PrintMethod = "printCTAGroup";
622604}
623605
624- multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
625- defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
626- defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
627- defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
628- defvar rc = !if(is_shared32, B32, B64);
629-
630- defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
631- defvar im2col_dag = !if(!eq(mode, "im2col"),
632- !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)),
633- (ins));
634- defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", ");
635- defvar im2col_asm_str = ", {{" # im2col_str # "}}";
636-
637- defvar asm_str = !if(!eq(mode, "im2col"),
638- !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
// TImmLeaf predicates over an i32 immediate, passed as the `cta_group_type`
// argument of TMA_TENSOR_G2S_INTR, where they are bound to $cg in the
// intrinsic selection dags:
//   tma_cta_group_imm0    - accepts only the value 0.
//   tma_cta_group_imm_any - accepts any non-negative value (Imm >= 0); this is
//                           the multiclass's default for `cta_group_type`.
606+ def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>;
607+ def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>;
639608
640- def "" : NVPTXInst<(outs),
641- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
642- !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>,
643- Requires<[hasPTX<80>, hasSM<90>]>;
644- def _MC : NVPTXInst<(outs),
645- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
646- (ins B16:$mc, CTAGroupFlags:$cg)),
647- !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>,
648- Requires<[hasPTX<80>, hasSM<90>]>;
649- def _CH : NVPTXInst<(outs),
650- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
651- (ins B64:$ch, CTAGroupFlags:$cg)),
652- !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>,
653- Requires<[hasPTX<80>, hasSM<90>]>;
654- def _MC_CH : NVPTXInst<(outs),
655- !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
656- (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)),
657- !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>,
658- Requires<[hasPTX<80>, hasSM<90>]>;
659- }
660-
661- foreach dim = [1, 2, 3, 4, 5] in {
662- foreach shared32 = [true, false] in {
663- foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
664- defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name :
665- CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>;
666- }
667- }
668- }
669-
670- multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> {
609+ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred,
610+ TImmLeaf cta_group_type = tma_cta_group_imm_any> {
671611 defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
672612 defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
673613 defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
@@ -697,10 +637,10 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>
697637 !setdagop(dims_dag, intr),
698638 !setdagop(im2col_dag, intr),
699639 (intr B16:$mc, B64:$ch));
700- defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, timm :$cg));
701- defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, timm :$cg));
702- defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, timm :$cg));
703- defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm :$cg));
640+ defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, cta_group_type :$cg));
641+ defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, cta_group_type :$cg));
642+ defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, cta_group_type :$cg));
643+ defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, cta_group_type :$cg));
704644
705645 def "" : NVPTXInst<(outs), ins_dag,
706646 inst_name # asm_str # ";",
@@ -719,14 +659,30 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>
719659 [intr_dag_with_mc_ch]>,
720660 Requires<pred>;
721661}
662+
// "tile" mode instantiations of TMA_TENSOR_G2S_INTR for dim = 1..5.
// Two sets of records are emitted per dimension:
//   TMA_G2S_TILE_CG0_<dim>D - gated on hasPTX<80> + hasSM<90>; $cg is
//                             restricted to 0 through tma_cta_group_imm0.
//   TMA_G2S_TILE_<dim>D     - gated on callSubtarget<"hasTMABlackwellSupport">;
//                             no fourth argument is given, so the multiclass
//                             default (tma_cta_group_imm_any) applies to $cg.
// NOTE(review): both sets can match when $cg == 0 and both predicate lists
// hold; which one wins is decided outside this hunk -- confirm if it matters.
663+ foreach dim = 1...5 in {
664+ defm TMA_G2S_TILE_CG0_ # dim # "D"
665+ : TMA_TENSOR_G2S_INTR<dim, "tile", [hasPTX<80>, hasSM<90>],
666+ tma_cta_group_imm0>;
667+ defm TMA_G2S_TILE_ # dim # "D"
668+ : TMA_TENSOR_G2S_INTR<dim, "tile",
669+ [callSubtarget<"hasTMABlackwellSupport">]>;
670+ }
// im2col-family instantiations of TMA_TENSOR_G2S_INTR for dim = 3..5.
//   TMA_G2S_IM2COL_CG0_<dim>D - "im2col" mode, gated on hasPTX<80> + hasSM<90>,
//                               $cg restricted to 0 (tma_cta_group_imm0).
//   TMA_G2S_IM2COL_<dim>D     - "im2col" mode, gated on
//                               callSubtarget<"hasTMABlackwellSupport">, using
//                               the multiclass default for cta_group_type.
//   TMA_G2S_<MODE>_<dim>D     - for mode in {"im2col_w", "im2col_w_128"} (name
//                               upper-cased via !toupper), gated on
//                               callSubtarget<"hasTMABlackwellSupport">.
// The `-` line below is the previous predicate ([hasTMACTAGroupSupport]) that
// this change replaces.
722671foreach dim = 3...5 in {
672+ defm TMA_G2S_IM2COL_CG0_ # dim # "D"
673+ : TMA_TENSOR_G2S_INTR<dim, "im2col", [hasPTX<80>, hasSM<90>],
674+ tma_cta_group_imm0>;
675+ defm TMA_G2S_IM2COL_ # dim # "D"
676+ : TMA_TENSOR_G2S_INTR<dim, "im2col",
677+ [callSubtarget<"hasTMABlackwellSupport">]>;
723678 foreach mode = ["im2col_w", "im2col_w_128"] in {
724679 defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D"
725- : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>;
680+ : TMA_TENSOR_G2S_INTR<dim, mode,
681+ [callSubtarget<"hasTMABlackwellSupport">]>;
726682 }
727683}
// "tile_gather4" instantiation of TMA_TENSOR_G2S_INTR. The record name says 2D
// but the multiclass is instantiated with dim = 5.
// NOTE(review): presumably gather4 takes five coordinate operands -- confirm
// against the PTX ISA. The predicate changes from hasTMACTAGroupSupport to
// callSubtarget<"hasTMABlackwellSupport">.
728684defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4",
729- [hasTMACTAGroupSupport ]>;
685+ [callSubtarget<"hasTMABlackwellSupport"> ]>;
730686
731687multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> {
732688 defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
@@ -784,7 +740,8 @@ foreach dim = 3...5 in {
784740 : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>;
785741
// "im2col_w_128" instantiation of TMA_TENSOR_G2S_CTA_INTR for the enclosing
// loop's `dim`. Its predicate changes from hasTMACTAGroupSupport to
// callSubtarget<"hasTMABlackwellSupport">, whereas the neighbouring "im2col_w"
// instantiation uses [hasPTX<86>, hasSM<100>].
786742 defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D"
787- : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>;
743+ : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128",
744+ [callSubtarget<"hasTMABlackwellSupport">]>;
788745}
// "tile_gather4" instantiation of TMA_TENSOR_G2S_CTA_INTR, gated on
// hasPTX<86> + hasSM<100>. As with the other *_GATHER4_2D records, the
// multiclass is instantiated with dim = 5 although the name says 2D.
789746defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4",
790747 [hasPTX<86>, hasSM<100>]>;
@@ -835,7 +792,7 @@ foreach dim = 1...5 in {
835792 }
836793}
// "tile_scatter4" instantiation of TMA_TENSOR_S2G_INTR with dim = 5 (the
// record name says 2D). The predicate changes from hasTMACTAGroupSupport to
// callSubtarget<"hasTMABlackwellSupport">.
837794defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4",
838- [hasTMACTAGroupSupport ]>;
795+ [callSubtarget<"hasTMABlackwellSupport"> ]>;
839796
840797def TMAReductionFlags : Operand<i32> {
841798 let PrintMethod = "printTmaReductionMode";
@@ -930,11 +887,11 @@ foreach dim = 3...5 in {
930887 foreach mode = ["im2col_w", "im2col_w_128"] in {
931888 defvar suffix = !toupper(mode) # "_" # dim # "D";
932889 defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode,
933- [hasTMACTAGroupSupport ]>;
890+ [callSubtarget<"hasTMABlackwellSupport"> ]>;
934891 }
935892}
// "tile_gather4" instantiation of TMA_TENSOR_PREFETCH_INTR with dim = 5 (the
// record name says 2D). The predicate changes from hasTMACTAGroupSupport to
// callSubtarget<"hasTMABlackwellSupport">.
936893defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
937- [hasTMACTAGroupSupport ]>;
894+ [callSubtarget<"hasTMABlackwellSupport"> ]>;
938895
939896//Prefetchu and Prefetch
940897
0 commit comments