Skip to content

Commit ae84826

Browse files
authored
[Revert] [Coalesce] Fix the default order to be row major (triton-lang#5707) triton-lang#7143 (triton-lang#7380)
Revert of triton-lang#7143 As per #832 (comment) Trying to resolve #832 on the release 3.4.x branch
1 parent f81f19a commit ae84826

File tree

4 files changed

+7
-25
lines changed

4 files changed

+7
-25
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
3535
// Return true if the Load uses block pointer.
3636
bool isLoadFromTensorPtr(triton::LoadOp op);
3737

38-
// Gets the order of a tensor from its contiguity. Places the dimensions with
39-
// the largest contiguity as the inner most dimension. If the contiguity is
40-
// all ones, returns the order {dim - 1, dim - 2, ..., 0}
41-
SmallVector<unsigned, 4>
42-
getOrderFromContiguity(const SmallVector<int64_t> &contiguity);
38+
// Return an array of indices enumerating the elements of 'arr' in descending
39+
// order (so that result[i] is the index of the i-th largest element of 'arr')
40+
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr);
4341

4442
// Return the operand used to access the memory in the operation
4543
Value getMemAccessPtr(Operation *op);

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
3838
});
3939

4040
auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
41-
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
41+
SmallVector<unsigned> order = argSort(contiguity);
4242
LDBG("order=[" << triton::join(order, ", ") << "]");
4343

4444
auto matchesShape = [&refTensorType](const Value &val) {
@@ -55,8 +55,8 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
5555
Value val = getMemAccessPtr(use);
5656
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
5757
continue;
58-
auto currOrder = getOrderFromContiguity(
59-
axisInfoAnalysis.getAxisInfo(val)->getContiguity());
58+
auto currOrder =
59+
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
6060
if (order == currOrder) {
6161
LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
6262
memAccessesSameOrder.insert(use);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,9 @@ bool isLoadFromTensorPtr(triton::LoadOp op) {
9292
return mlir::triton::isTensorPointerType(op.getPtr().getType());
9393
}
9494

95-
SmallVector<unsigned, 4>
96-
getOrderFromContiguity(const SmallVector<int64_t> &arr) {
95+
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr) {
9796
SmallVector<unsigned, 4> ret(arr.size());
9897
std::iota(ret.begin(), ret.end(), 0);
99-
std::reverse(ret.begin(), ret.end());
10098
std::stable_sort(ret.begin(), ret.end(),
10199
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
102100
return ret;

test/TritonGPU/coalesce.mlir

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,3 @@ tt.func @coalesce_poison(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1
199199
}
200200

201201
}
202-
203-
// -----
204-
205-
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
206-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
207-
tt.func public @load_3D_contig_1(%arg: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
208-
%50 = tt.splat %arg : !tt.ptr<i8> -> tensor<32x4x4x!tt.ptr<i8>, #blocked>
209-
// This checks that the pass picks the row-major ordering by default for elements with contiguity 1.
210-
// CHECK: #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
211-
// CHECK: tt.load %1 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
212-
%108 = tt.load %50 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
213-
tt.return
214-
}
215-
}

0 commit comments

Comments
 (0)