
Commit 2bc0153

aakhundov authored and jbdalido committed
[BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis (triton-lang#5181)
Fixes triton-lang#5122.

The `ProgramPoint` [here](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1087) is created on the stack. Its address is then [passed](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1088-L1089) to the MLIR `SparseAnalysis` code, where it is [added as a dependency](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L311) and later [dereferenced](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L90). By the time the `ProgramPoint` is dereferenced in `AbstractSparseForwardDataFlowAnalysis::visit`, `AxisInfoAnalysis::visitForOpInductionVar` has already returned and the `ProgramPoint` stack variable has been destroyed. This leads to a segfault (which can be reproduced on the base rev with the lit test added in this PR).

The code modified in this PR was originally added in triton-lang#4927, in conjunction with updating the `llvm-project` hash to `b5cc222d7429`. However, as noted in llvm/llvm-project#110344 (the `llvm-project` PR that made the refactoring prompting the `AxisInfo.cpp` change in triton-lang#4927):

> For dense forward data-flow analysis and other analysis (except dense backward data-flow analysis), the program point corresponding to the original operation can be obtained by `getProgramPointAfter(op)`

As `AxisInfoAnalysis` (in Triton) inherits from `SparseForwardDataFlowAnalysis` (in MLIR), this PR follows the guidance above, which resolves the segfault: the `ProgramPoint` is now stored in the instance-level state of the pass rather than on the stack.

P.S. The lit test added in this PR is not exactly minimal. However, I did my best to minimize it, starting from the 400-line repro TTGIR in the original issue (triton-lang#5122).
1 parent ba5764a commit 2bc0153
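
The lifetime bug described above is easiest to see stripped of the MLIR specifics. Below is a minimal, self-contained C++ sketch of the same pattern; `MiniSolver`, `buggyVisit`, and `fixedVisit` are invented stand-ins for the MLIR dataflow solver and the Triton visitor, not the actual APIs:

```cpp
#include <iostream>
#include <vector>

// MiniSolver stands in for MLIR's DataFlowSolver, which records ProgramPoint
// pointers as dependencies and dereferences them later, long after the visit
// call that supplied them has returned.
struct MiniSolver {
  std::vector<const int *> deps; // stand-in for the stored ProgramPoint*
  int ownedPoint = 0;            // solver-owned storage (the fix)

  void addDependency(const int *point) { deps.push_back(point); }

  // Like getProgramPointAfter(op): returns a pointer into solver-owned state.
  const int *getProgramPointAfter(int value) {
    ownedPoint = value;
    return &ownedPoint;
  }

  void visitLater() {
    for (const int *point : deps)
      std::cout << *point << '\n'; // UB if *point was a dead stack slot
  }
};

// The triton-lang#4927 pattern: the "program point" is a local variable whose
// address escapes into the solver and outlives this stack frame.
void buggyVisit(MiniSolver &solver) {
  int programPoint = 42;               // destroyed when this function returns
  solver.addDependency(&programPoint); // dangling once we return
}

// The triton-lang#5181 pattern: ask the framework for a pointer it owns.
void fixedVisit(MiniSolver &solver) {
  solver.addDependency(solver.getProgramPointAfter(42)); // stays valid
}

int main() {
  MiniSolver solver;
  fixedVisit(solver);  // swap in buggyVisit(solver) and run under ASan to
  solver.visitLater(); // observe the use-after-return it can diagnose
}
```

With `buggyVisit` swapped in, the deferred dereference hits a dead stack slot, which is the same class of failure the lit test reproduced on the base rev.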

File tree

2 files changed: +5 −39 lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -1102,10 +1102,15 @@ LogicalResult AxisInfoAnalysis::visitOperation(
 void AxisInfoAnalysis::visitForOpInductionVar(
     scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
   ProgramPoint *programPoint = getProgramPointAfter(op);
+<<<<<<< HEAD
   const auto &lb =
       getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
   const auto &step =
       getLatticeElementFor(programPoint, op.getStep())->getValue();
+=======
+  auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
+  auto step = getLatticeElementFor(programPoint, op.getStep())->getValue();
+>>>>>>> 24c0fe47a ([BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis (#5181))

   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
```
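
The hunk above lands with unresolved merge-conflict markers; note that both sides of the conflict already pass the framework-owned `programPoint` to `getLatticeElementFor`. For reference, a sketch of the function as the commit title intends it, assuming the incoming `24c0fe47a` side is the one kept and with the rest of the body elided:

```cpp
void AxisInfoAnalysis::visitForOpInductionVar(
    scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
  // getProgramPointAfter(op) returns a ProgramPoint owned by the analysis
  // framework, so the pointer the solver records as a dependency stays valid
  // after this function returns (unlike a stack-allocated ProgramPoint).
  ProgramPoint *programPoint = getProgramPointAfter(op);
  auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
  auto step = getLatticeElementFor(programPoint, op.getStep())->getValue();

  AxisInfo::DimVectorT knownContiguity(1, 1);
  AxisInfo::DimVectorT knownDivisibility(1, 1);
  // ... contiguity/divisibility computation continues as before ...
}
```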

test/TritonGPU/coalesce.mlir

Lines changed: 0 additions & 39 deletions
```diff
@@ -160,42 +160,3 @@ module {
     tt.return
   }
 }
-
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-
-// CHECK: [[COALESCED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
-
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-
-  // CHECK: @coalesce_poison
-  tt.func @coalesce_poison(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i1) {
-    %c0_i32 = arith.constant 0 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked>
-    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1>
-    %2 = ttg.convert_layout %1 : tensor<128xi32, #blocked1> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
-    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2>
-    %4 = ttg.convert_layout %3 : tensor<128x1xi32, #blocked2> -> tensor<128x1xi32, #blocked3>
-    %5 = tt.broadcast %4 {axis = 1 : i32} : tensor<128x1xi32, #blocked3> -> tensor<128x64xi32, #blocked3>
-    %6 = ttg.convert_layout %5 : tensor<128x64xi32, #blocked3> -> tensor<128x64xi32, #blocked>
-    %7 = tt.addptr %0, %6 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
-
-    %8 = ub.poison : tensor<128x64x!tt.ptr<f16>, #blocked>
-    // CHECK: scf.if
-    %9 = scf.if %arg2 -> (tensor<128x64x!tt.ptr<f16>, #blocked>) {
-      scf.yield %8 : tensor<128x64x!tt.ptr<f16>, #blocked>
-    } else {
-      scf.yield %7 : tensor<128x64x!tt.ptr<f16>, #blocked>
-    }
-    // CHECK: [[PTR:%.*]] = ttg.convert_layout %{{.*}} : tensor<128x64x!tt.ptr<f16>, #{{.*}}> -> tensor<128x64x!tt.ptr<f16>, [[COALESCED_LAYOUT]]>
-    // CHECK-NEXT: tt.load [[PTR]]
-    %10 = tt.load %9 : tensor<128x64x!tt.ptr<f16>, #blocked>
-    tt.return
-  }
-
-}
```
