Skip to content

Commit d71421d

Browse files
authored
[Backend] Try to fix infinite loop in membar (#5973)
1 parent 4cbf3c2 commit d71421d

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

lib/Analysis/Membar.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ void MembarAnalysis::resolve(FunctionOpInterface funcOp,
6060
// the outputBlockInfo, we skip the successors
6161
continue;
6262
}
63-
// Update the current block
64-
outputBlockInfoMap[block].join(inputBlockInfo);
63+
// Update the current block. The block transfer function is not monotonic,
64+
// so overwrite the output state entirely instead of joining into it.
65+
outputBlockInfoMap[block] = inputBlockInfo;
6566
// Update the successors
6667
for (auto *successor : successors) {
6768
inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);

test/Analysis/test-membar.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,44 @@ tt.func @tma_special_cases_cf(%arg1: !tt.ptr<i8, 0>, %i1 : i1, %arg2: tensor<256
828828
tt.return %t : tensor<256x64xf16, #blocked>
829829
}
830830
}
831+
832+
// -----
833+
834+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
835+
#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
836+
#smem = #ttg.shared_memory
837+
838+
module attributes {"ttg.num-warps" = 4 : i32} {
839+
840+
// CHECK-LABEL: @direct_backedge_within_loop
841+
tt.func @direct_backedge_within_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>, %arg5: i1) {
842+
// CHECK-NEXT: constant
843+
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked>
844+
// CHECK-NEXT: local_alloc
845+
%0 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
846+
// CHECK-NEXT: barrier
847+
// CHECK-NEXT: local_load
848+
%1 = ttg.local_load %0 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked>
849+
// CHECK-NEXT: br
850+
cf.br ^bb1(%arg0, %0 : index, !ttg.memdesc<128x32xf16, #shared, #smem>)
851+
^bb1(%2: index, %3: !ttg.memdesc<128x32xf16, #shared, #smem>):
852+
cf.cond_br %arg5, ^bb2, ^bb3
853+
// CHECK: ^bb2:
854+
^bb2:
855+
// CHECK-NEXT: barrier
856+
// CHECK-NEXT: local_alloc
857+
%4 = ttg.local_alloc %cst : (tensor<128x32xf16, #blocked>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
858+
// CHECK-NEXT: br
859+
cf.br ^bb1(%arg1, %4 : index, !ttg.memdesc<128x32xf16, #shared, #smem>)
860+
// CHECK: ^bb3
861+
^bb3:
862+
// CHECK-NEXT: barrier
863+
// CHECK-NEXT: local_load
864+
%5 = ttg.local_load %3 : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #blocked>
865+
// CHECK-NEXT: cond_br
866+
cf.cond_br %arg5, ^bb3, ^bb4
867+
^bb4:
868+
tt.return
869+
}
870+
871+
}

0 commit comments

Comments
 (0)