Skip to content

Commit 23719b7

Browse files
authored
Revert "[Coalesce] Fix the default order to be row major " (#5744)
Reverts triton-lang/triton#5707 This causes some functional changes that I need to investigate
1 parent 0c7edf9 commit 23719b7

File tree

5 files changed

+8
-25
lines changed

5 files changed

+8
-25
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
3333
// Return true if the Load uses block pointer.
3434
bool isLoadFromTensorPtr(triton::LoadOp op);
3535

36-
// Gets the order of a tensor from its contiguity. Places the dimensions with
37-
// the largest contiguity as the inner most dimension. If the contiguity is
38-
// all ones, returns the order {dim - 1, dim - 2, ..., 0}
39-
SmallVector<unsigned, 4>
40-
getOrderFromContiguity(const SmallVector<int64_t> &contiguity);
36+
// Return an array of indices enumerating the elements of 'arr' in descending
37+
// order (so that result[i] is the index of the i-th largest element of 'arr')
38+
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr);
4139

4240
// Return the operand used to access the memory in the operation
4341
Value getMemAccessPtr(Operation *op);

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
3838
});
3939

4040
auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
41-
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
41+
SmallVector<unsigned> order = argSort(contiguity);
4242
LDBG("order=[" << triton::join(order, ", ") << "]");
4343

4444
auto matchesShape = [&refTensorType](const Value &val) {
@@ -55,8 +55,8 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
5555
Value val = getMemAccessPtr(use);
5656
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
5757
continue;
58-
auto currOrder = getOrderFromContiguity(
59-
axisInfoAnalysis.getAxisInfo(val)->getContiguity());
58+
auto currOrder =
59+
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
6060
if (order == currOrder) {
6161
LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
6262
memAccessesSameOrder.insert(use);

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) {
341341
int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
342342
tt::AxisInfo::DimVectorT contiguity =
343343
axisInfo.getAxisInfo(src)->getContiguity();
344-
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
344+
SmallVector<unsigned> order = argSort(contiguity);
345345
unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo);
346346
SmallVector<unsigned> sizePerThread(order.size(), 1);
347347
sizePerThread[order[0]] = currPerThread;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,9 @@ bool isLoadFromTensorPtr(triton::LoadOp op) {
8989
return mlir::triton::isTensorPointerType(op.getPtr().getType());
9090
}
9191

92-
SmallVector<unsigned, 4>
93-
getOrderFromContiguity(const SmallVector<int64_t> &arr) {
92+
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr) {
9493
SmallVector<unsigned, 4> ret(arr.size());
9594
std::iota(ret.begin(), ret.end(), 0);
96-
std::reverse(ret.begin(), ret.end());
9795
std::stable_sort(ret.begin(), ret.end(),
9896
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
9997
return ret;

test/TritonGPU/coalesce.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,3 @@ module {
160160
tt.return
161161
}
162162
}
163-
164-
// -----
165-
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
166-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
167-
tt.func public @load_3D_contig_1(%arg: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
168-
%50 = tt.splat %arg : !tt.ptr<i8> -> tensor<32x4x4x!tt.ptr<i8>, #blocked>
169-
// This checks that the pass picks the row-major ordering by default for elements with contiguity 1.
170-
// CHECK: #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
171-
// CHECK: tt.load %1 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
172-
%108 = tt.load %50 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
173-
tt.return
174-
}
175-
}

0 commit comments

Comments
 (0)