Skip to content

Commit 29a8282

Browse files
committed
Fix AxisInfo rank mismatch for poison tensor pointers
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 2beb243 commit 29a8282

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
276276
getAxisInfo(ub::PoisonOp op,
277277
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
278278
unsigned rank = 1;
279-
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
279+
if (auto shape = dyn_cast<RankedTensorType>(op.getType())) {
280280
rank = shape.getRank();
281+
} else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
282+
if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
283+
rank = tensorType.getRank();
284+
}
281285

282286
// Poison values are never accessed, thus assume optimistic values.
283287
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-assign-latencies=num-stages=3 -canonicalize | FileCheck %s
2+
3+
// Test that ub.poison producing a ptr<tensor> gets correct rank in AxisInfo
4+
// analysis (rank=2 for tensor<128x64>, not rank=1).
5+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
6+
// CHECK-LABEL: @test_poison_rank
7+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
8+
tt.func public @test_poison_rank(%arg0: !tt.ptr<f16>, %lb: i32, %ub: i32) {
9+
%c0_i32 = arith.constant 0 : i32
10+
%c1_i32 = arith.constant 1 : i32
11+
%c1_i64 = arith.constant 1 : i64
12+
%c128_i64 = arith.constant 128 : i64
13+
%c64_i64 = arith.constant 64 : i64
14+
15+
%0 = ub.poison : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
16+
17+
%1 = tt.make_tensor_ptr %arg0, [%c128_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
18+
19+
%result = scf.for %i = %lb to %ub step %c1_i32
20+
iter_args(%ptr = %0) -> !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> : i32 {
21+
22+
%advanced = tt.advance %ptr, [%c0_i32, %c0_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
23+
24+
scf.yield %advanced : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
25+
}
26+
27+
tt.return
28+
}
29+
}

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
295295
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
296296
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
297297
// Poison values are never accessed, thus assume optimistic values.
298-
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
298+
Type type = op.getType();
299+
if (auto ptrTy = dyn_cast<triton::PointerType>(type)) {
300+
type = ptrTy.getPointeeType();
301+
}
302+
303+
if (auto shape = dyn_cast<mlir::ShapedType>(type)) {
299304
unsigned rank = shape.getRank();
300305
return AxisInfo(
301306
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),

0 commit comments

Comments
 (0)