@@ -51,7 +51,7 @@ tt.func @one_dep_async(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -75,7 +75,7 @@ tt.func @different_use_stages(%lb : index, %ub : index, %step : index,
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
     "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
     "use2"(%a) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 3 : i32}
   tt.return
 }
 }
@@ -106,7 +106,7 @@ tt.func @used_by_if_yield(%lb : index, %ub : index, %step : index,
       scf.yield %init_a : tensor<128x32xf16, #A>
     } {loop.cluster = 0 : i32, loop.stage = 2 : i32}
     "use"(%a_if) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 3 : i32}
   tt.return
 }
 }
@@ -124,7 +124,7 @@ tt.func @dist1_load(%lb : index, %ub : index, %step : index,
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
     scf.yield %a : tensor<128x32xf16, #A>
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -142,7 +142,7 @@ tt.func @one_dep_sync(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x!tt.ptr<f16>, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -183,7 +183,7 @@ tt.func @one_dep_local_alloc(%lb : index, %ub : index, %step : index,
     %a_alloc = ttg.local_alloc %a {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
     %a_load = ttg.local_load %a_alloc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #A>
     "use"(%a_load) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -214,7 +214,7 @@ tt.func @one_load_group(%lb : index, %ub : index, %step : index,
     %b = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
     "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
     "use2"(%b) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -255,7 +255,7 @@ tt.func @two_load_groups(%lb : index, %ub : index, %step : index,
255255 " use1" (%a ){loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf32 , #A >) -> ()
256256 " use2" (%b ){loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf32 , #A >) -> ()
257257 " use3" (%c ){loop.cluster = 0 : i32 , loop.stage = 3 : i32 } : (tensor <128 x32 xf32 , #A >) -> ()
258- }
258+ } { tt.scheduled_max_stage = 3 : i32 }
259259 tt.return
260260}
261261}
@@ -304,7 +304,7 @@ tt.func @dependent_loads(%lb : index, %ub : index, %step : index,
304304 %b = " pointerize" (%a ) {loop.cluster = 2 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf32 , #A >) -> tensor <128 x32 x!tt.ptr <f32 >, #A >
305305 %c = tt.load %b {loop.cluster = 2 : i32 , loop.stage = 2 : i32 } : tensor <128 x32 x!tt.ptr <f32 >, #A >
306306 " use1" (%c ){loop.cluster = 0 : i32 , loop.stage = 4 : i32 } : (tensor <128 x32 xf32 , #A >) -> ()
307- }
307+ } { tt.scheduled_max_stage = 4 : i32 }
308308 tt.return
309309}
310310}
@@ -361,7 +361,7 @@ tt.func @dependent_loads_asymmetric(%lb : index, %ub : index, %step : index,
361361 %b = " pointerize" (%a ) {loop.cluster = 2 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf32 , #A >) -> tensor <128 x32 x!tt.ptr <f32 >, #A >
362362 %c = tt.load %b {loop.cluster = 2 : i32 , loop.stage = 2 : i32 } : tensor <128 x32 x!tt.ptr <f32 >, #A >
363363 " use1" (%c ){loop.cluster = 0 : i32 , loop.stage = 5 : i32 } : (tensor <128 x32 xf32 , #A >) -> ()
364- }
364+ } { tt.scheduled_max_stage = 5 : i32 }
365365 tt.return
366366}
367367}
@@ -379,7 +379,7 @@ tt.func @unused_load(%lb : index, %ub : index, %step : index,
     // CHECK: dummy
     %a = tt.load %a_ptr_init {loop.cluster = 0 : i32, loop.stage = 1 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
     "dummy"() : () -> ()
-  }
+  } {tt.scheduled_max_stage = 1 : i32}
   tt.return
 }
 }
@@ -434,7 +434,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
       %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
       scf.yield %acc_res : tensor<128x128xf32, #mma>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
     tt.return %res_f16 : tensor<128x128xf16, #mma>
   }
@@ -489,7 +489,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>
       %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
       scf.yield %acc_res : tensor<128x128xf32, #mma>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     tt.return %res : tensor<128x128xf32, #mma>
   }
 }
@@ -555,7 +555,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
       %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
       scf.yield %acc_res : tensor<128x128xf32, #mma>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
     tt.return %res_f16 : tensor<128x128xf16, #mma>
   }
@@ -614,7 +614,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, i1, i1) -> ()
       %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
       scf.yield %acc_res : tensor<128x128xf32, #blocked>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
     tt.return %res_f16 : tensor<128x128xf16, #blocked>
   }
@@ -669,7 +669,7 @@ tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.experimental_descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x32xf16>> -> tensor<128x32xf16, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -725,7 +725,7 @@ tt.func @tma_gather_lowering(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.experimental_descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc<tensor<1x128xf32>>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf32, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -760,7 +760,7 @@ tt.func @tma_reuse_barrier(%lb : index, %ub : index, %step : index,
760760 " use2" (%b ) {loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf16 , #A >) -> ()
761761 %c = tt.experimental_descriptor_load %descC [%offs , %offs ] {loop.cluster = 2 : i32 , loop.stage = 0 : i32 } : !tt.tensordesc <tensor <128 x32 xf16 >> -> tensor <128 x32 xf16 , #A >
762762 " use3" (%c ) {loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf16 , #A >) -> ()
763- }
763+ } { tt.scheduled_max_stage = 2 : i32 }
764764 tt.return
765765}
766766}
@@ -798,7 +798,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
       %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
       scf.yield %acc_res : tensor<128x128xf32, #mma>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
     tt.return %res_f16 : tensor<128x128xf16, #mma>
   }
@@ -833,7 +833,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     scf.for %iv = %lb to %ub step %step : index {
       %desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : <f16>, <tensor<128x128xf16>>
       "use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (!tt.tensordesc<tensor<128x128xf16>>) -> ()
-    }
+    } {tt.scheduled_max_stage = 1 : i32}
     tt.return
   }
 }
@@ -879,7 +879,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
       ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> ()
       %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
       scf.yield %acc_res : tensor<128x128xf32, #blocked>
-    }
+    } {tt.scheduled_max_stage = 2 : i32}
     tt.return %res : tensor<128x128xf32, #blocked>
   }
 }