Skip to content

Commit d3a28a6

Browse files
Fix AxisInfo handling of PoisonOp producing MemDesc (#8489) (#5501)
The PoisonOpAxisInfoVisitor incorrectly returns rank=1 for ub.poison operations producing pointer-to-tensor types like !tt.ptr<tensor<128x64xf16>>. This wrong rank propagates through unrealized_conversion_cast operations created during lowering, causing assertion failures in AxisInfo::join(). Fixes #5464
2 parents a93c725 + 29a8282 commit d3a28a6

File tree

4 files changed

+57
-2
lines changed

4 files changed

+57
-2
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
276276
getAxisInfo(ub::PoisonOp op,
277277
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
278278
unsigned rank = 1;
279-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
279+
if (auto shape = dyn_cast<RankedTensorType>(op.getType())) {
280280
rank = shape.getRank();
281+
} else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
282+
if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
283+
rank = tensorType.getRank();
284+
}
281285

282286
// Poison values are never accessed, thus assume optimistic values.
283287
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),
@@ -1229,6 +1233,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
12291233
return rhs;
12301234
if (rhs.getRank() == 0)
12311235
return lhs;
1236+
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
12321237
DimVectorT contiguity;
12331238
DimVectorT divisibility;
12341239
DimVectorT constancy;

test/TritonGPU/pipeline-assign-latencies.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
11491149
}
11501150

11511151
// -----
1152+
// Test that ub.poison producing a memdesc does not get treated like a tensor
1153+
// value in AxisInfo analysis.
1154+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
1155+
#smem = #ttg.shared_memory
1156+
module attributes {"ttg.num-warps" = 4 : i32} {
1157+
tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> {
1158+
%c1 = arith.constant 1 : i32
1159+
%poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1160+
%normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1161+
%result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 {
1162+
scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1163+
}
1164+
tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
1165+
}
1166+
}
11521167

1168+
// -----
11531169
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
11541170
#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
11551171
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-assign-latencies=num-stages=3 -canonicalize | FileCheck %s
2+
3+
// Test that ub.poison producing a ptr<tensor> gets correct rank in AxisInfo
4+
// analysis (rank=2 for tensor<128x64>, not rank=1).
5+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
6+
// CHECK-LABEL: @test_poison_rank
7+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
8+
tt.func public @test_poison_rank(%arg0: !tt.ptr<f16>, %lb: i32, %ub: i32) {
9+
%c0_i32 = arith.constant 0 : i32
10+
%c1_i32 = arith.constant 1 : i32
11+
%c1_i64 = arith.constant 1 : i64
12+
%c128_i64 = arith.constant 128 : i64
13+
%c64_i64 = arith.constant 64 : i64
14+
15+
%0 = ub.poison : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
16+
17+
%1 = tt.make_tensor_ptr %arg0, [%c128_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
18+
19+
%result = scf.for %i = %lb to %ub step %c1_i32
20+
iter_args(%ptr = %0) -> !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> : i32 {
21+
22+
%advanced = tt.advance %ptr, [%c0_i32, %c0_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
23+
24+
scf.yield %advanced : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
25+
}
26+
27+
tt.return
28+
}
29+
}

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
295295
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
296296
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
297297
// Poison values are never accessed, thus assume optimistic values.
298-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
298+
Type type = op.getType();
299+
if (auto ptrTy = dyn_cast<triton::PointerType>(type)) {
300+
type = ptrTy.getPointeeType();
301+
}
302+
303+
if (auto shape = dyn_cast<mlir::ShapedType>(type)) {
299304
unsigned rank = shape.getRank();
300305
return AxisInfo(
301306
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),

0 commit comments

Comments
 (0)