Skip to content

Commit face3d2

Browse files
authored
[LinearLayout] Fix crash with debug (triton-lang#6272)
This PR introduces a stronger check for empty layouts in invertAndCompose to fix compiler crashes when running in debug mode.
1 parent 40f3945 commit face3d2

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

lib/Tools/LinearLayout.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,8 @@ LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) {
837837

838838
// We need names for the in/out dim of the flattened layout we're going to
839839
// read off from `m`. These could be anything, doesn't matter.
840+
assert(!A.getInDimNames().empty() &&
841+
"attempt to solve lstsq for empty layout");
840842
StringAttr inDim1D = *A.getInDimNames().begin();
841843
StringAttr outDim1D = *A.getOutDimNames().begin();
842844

@@ -927,9 +929,8 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
927929
auto BReduced = B.sublayout(BNonIdentityInDims, outDims);
928930

929931
// If one is empty, the other must be empty as well
930-
assert((AReduced == LinearLayout::empty()) ==
931-
(BReduced == LinearLayout::empty()));
932-
bool isEmpty = AReduced == LinearLayout::empty();
932+
assert((ANonIdentityInDims.empty()) == (BNonIdentityInDims.empty()));
933+
bool isEmpty = ANonIdentityInDims.empty();
933934

934935
auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced);
935936

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm --debug| FileCheck %s
2+
3+
// CHECK-LABEL: convert_identity
4+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
5+
#smem = #ttg.shared_memory
6+
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
7+
tt.func public @convert_identity(%arg0: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
8+
%1 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked>
9+
tt.return
10+
}
11+
}

0 commit comments

Comments
 (0)