@@ -828,3 +828,44 @@ tt.func @tma_special_cases_cf(%arg1: !tt.ptr<i8, 0>, %i1 : i1, %arg2: tensor<256
828828 tt.return %t : tensor <256 x64 xf16 , #blocked >
829829}
830830}
831+
// -----
833+
// Attribute aliases used by the test case in this split:
//   #blocked - encoding attached to the 128x32 f16 tensors.
//   #shared  - encoding attached to the memdesc values.
//   #smem    - memory-space attribute attached to the memdesc values.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#smem = #ttg.shared_memory
837+
module attributes {"ttg.num-warps" = 4 : i32} {

// Control-flow shape exercised by this case:
//   entry  allocates buffer %0, loads from it, then branches to ^bb1 passing %0.
//   ^bb1   receives the buffer as block argument %3 and branches on %arg5.
//   ^bb2   allocates a fresh buffer %4 and branches straight back to ^bb1 with
//          it (the back-edge the function name refers to).
//   ^bb3   loads from the block argument %3 and may branch back to itself.
//   ^bb4   returns.
// The FileCheck directives below expect a barrier ahead of the first
// local_load, ahead of the local_alloc in ^bb2, and ahead of the local_load
// in ^bb3.
// NOTE(review): the RUN line is not visible in this chunk; presumably the
// barriers are inserted by the pass under test - confirm against the file header.
// %arg2, %arg3, %arg4 and the block argument %2 are not referenced in the body.
// CHECK-LABEL: @direct_backedge_within_loop
tt.func @direct_backedge_within_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>, %arg5: i1) {
  // CHECK-NEXT: constant
  %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked>
  // CHECK-NEXT: local_alloc
  %0 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
  // CHECK-NEXT: barrier
  // CHECK-NEXT: local_load
  %1 = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked>
  // CHECK-NEXT: br
  cf.br ^bb1(%arg0, %0 : index, !ttg.memdesc<128x32xf16, #shared, #smem>)
^bb1(%2: index, %3: !ttg.memdesc<128x32xf16, #shared, #smem>):
  cf.cond_br %arg5, ^bb2, ^bb3
// CHECK: ^bb2:
^bb2:
  // CHECK-NEXT: barrier
  // CHECK-NEXT: local_alloc
  %4 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
  // CHECK-NEXT: br
  cf.br ^bb1(%arg1, %4 : index, !ttg.memdesc<128x32xf16, #shared, #smem>)
// CHECK: ^bb3
^bb3:
  // CHECK-NEXT: barrier
  // CHECK-NEXT: local_load
  %5 = ttg.local_load %3 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked>
  // CHECK-NEXT: cond_br
  cf.cond_br %arg5, ^bb3, ^bb4
^bb4:
  tt.return
}

}
0 commit comments