Skip to content

Commit 4734af3

Browse files
authored
Fix AxisInfo handling of PoisonOp producing MemDesc (#8489)
AxisInfo analysis currently retrieves the rank from any `ShapedType` producing `PoisonOp`. This is a problem if the `PoisonOp` actually produces a `MemDesc`, since the value produced by the `PoisonOp` may flow into the same value as some other `MemDesc` producing operation, which will have been assigned the "pessimistic state" and have rank 1. When we attempt to join the two, the ranks will not match, potentially resulting in a crash. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 3f4ac9f commit 4734af3

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
276276
getAxisInfo(ub::PoisonOp op,
277277
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
278278
unsigned rank = 1;
279-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
279+
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
280280
rank = shape.getRank();
281281

282282
// Poison values are never accessed, thus assume optimistic values.
@@ -1227,6 +1227,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
12271227
return rhs;
12281228
if (rhs.getRank() == 0)
12291229
return lhs;
1230+
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
12301231
DimVectorT contiguity;
12311232
DimVectorT divisibility;
12321233
DimVectorT constancy;

test/TritonGPU/pipeline-assign-latencies.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,3 +1147,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
11471147
tt.return
11481148
}
11491149
}
1150+
1151+
// -----
1152+
1153+
// Test that ub.poison producing a memdesc does not get treated like a tensor
1154+
// value in AxisInfo analysis.
1155+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
1156+
#smem = #ttg.shared_memory
1157+
module attributes {"ttg.num-warps" = 4 : i32} {
1158+
tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> {
1159+
%c1 = arith.constant 1 : i32
1160+
%poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1161+
%normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1162+
%result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 {
1163+
scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1164+
}
1165+
tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1166+
}
1167+
}

0 commit comments

Comments
 (0)