@@ -58,3 +58,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
5858 tt.return
5959 }
6060}
61+
62+ // -----
// Test case tcgen5_with_commit: checks how ttng.tc_gen5_commit ops that follow an async
// ttng.tc_gen5_mma are handled. Per the CHECK lines below, the first mma (directly followed by
// two commits) is expected to be printed with both barriers attached: %barrier guarded by
// %barrierPred, and the slice of %barrier2 guarded by a constant true. The three later commits
// are each expected to remain as a standalone tc_gen5_commit.
// Aliases: #shared / #shared1 are the layouts of %a / %b, #shared2 is the layout of the barrier
// buffers, #tmem is the encoding of %c, and #smem is shorthand for #ttg.shared_memory.
// NOTE(review): the RUN line and the pass under test are outside this view; confirm which pass performs the merge.
63+ #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 32 , transposed = false , elementBitWidth = 8 }>
64+ #shared1 = #ttg.nvmma_shared <{swizzlingByteWidth = 32 , transposed = true , elementBitWidth = 8 }>
65+ #shared2 = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ], CTAsPerCGA = [1 ], CTASplitNum = [1 ], CTAOrder = [0 ]}>
66+ #tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 64 , unpacked = true >
67+ #smem = #ttg.shared_memory
68+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
69+ // CHECK-LABEL: tcgen5_with_commit
70+ tt.func @tcgen5_with_commit (
71+ // CHECK: [[BARRIER1:%.*]]: !ttg.memdesc<1xi64, #shared
72+ %barrier: !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >,
73+ // CHECK: [[BARRIER_PRED:%.*]]: i1,
74+ %barrierPred: i1 ,
75+ // CHECK: [[A_SMEM:%.*]]: !ttg.memdesc<128x128xf8E5M2
76+ %a: !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
77+ %b: !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
78+ %c: !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >) {
79+ %barrier2 = ttg.local_alloc : () -> !ttg.memdesc <2 x1 xi64 , #shared2 , #smem , mutable >
80+ %c0_i32 = arith.constant 0 : i32
81+ // CHECK: [[TRUE:%.*]] = arith.constant true
82+ // CHECK: [[BARRIER_SLICE:%.*]] = ttg.memdesc_index
83+ // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[BARRIER1]][[[BARRIER_PRED]]], [[BARRIER_SLICE]][[[TRUE]]]
84+ %accUse = arith.constant false
85+ %pred = arith.constant true
86+ ttng.tc_gen5_mma %a , %b , %c , %accUse , %pred {is_async } :
87+ !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
88+ !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
89+ !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >
90+ ttng.tc_gen5_commit %barrier , %barrierPred : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >
91+ %barrier_slice = ttg.memdesc_index %barrier2 [%c0_i32 ] : !ttg.memdesc <2 x1 xi64 , #shared2 , #smem , mutable > -> !ttg.memdesc <1 xi64 , #shared2 , #smem , mutable , 2 x1 >
92+ ttng.tc_gen5_commit %barrier_slice : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable , 2 x1 >
93+ // Next scenario: an mma with no commit directly after it, then a second mma nested in scf.if.
94+ ttng.tc_gen5_mma %a , %b , %c , %accUse , %pred {is_async } :
95+ !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
96+ !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
97+ !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >
98+ // %random_pred is a non-constant i1: it is the scf.if condition here and an mma predicate further down.
99+ %random_pred = arith.cmpi eq , %barrierPred , %pred : i1
100+ scf.if %random_pred {
101+ ttng.tc_gen5_mma %a , %b , %c , %accUse , %pred {is_async } :
102+ !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
103+ !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
104+ !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >
105+ }
106+ // This commit should not be merged into either of the two mma ops above (one unconditional, one nested in scf.if).
107+ // CHECK: tc_gen5_commit
108+ ttng.tc_gen5_commit %barrier , %barrierPred : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >
109+
110+ // The mma predicate (%random_pred) is not a constant true, so the commit op should not be merged.
111+ // CHECK: tc_gen5_commit
112+ ttng.tc_gen5_mma %a , %b , %c , %accUse , %random_pred {is_async } :
113+ !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
114+ !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
115+ !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >
116+ ttng.tc_gen5_commit %barrier : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >
117+
118+ // There is an impure op (ttng.wait_barrier) between the mma and commit ops. Merging must not happen in such cases.
119+ // CHECK: tc_gen5_commit
120+ ttng.tc_gen5_mma %a , %b , %c , %accUse , %pred {is_async } :
121+ !ttg.memdesc <128 x128 xf8 E5 M2 , #shared , #ttg.shared_memory >,
122+ !ttg.memdesc <128 x256 xf8 E5 M2 , #shared1 , #ttg.shared_memory >,
123+ !ttg.memdesc <128 x256 xf32 , #tmem , #ttng.tensor_memory , mutable >
124+ ttng.wait_barrier %barrier , %c0_i32 : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >
125+ ttng.tc_gen5_commit %barrier : !ttg.memdesc <1 xi64 , #shared2 , #ttg.shared_memory , mutable >
126+
127+ tt.return
128+ }
129+ }
0 commit comments