Skip to content

Commit 717997b

Browse files
ThomasRaouxmasahirawnhenry
authored
[Coalesce] Fix the default order to be row major (#5707) (#7143)
Try to reland: triton-lang/triton#5707 Taking over triton-lang/triton#4914 due to an inactivity As discussed there, when there are multiple "contiguity of 1" in the `contiguity` array, doing argsort on it means that the resulting `order` becomes ascending for those elements. In the unit test, `order = [2, 1, 0]` becomes `[0, 1, 2]`, which is odd. This convention seems arbitrary, so it is better to pick the row-major ordering by default in such case to be consistent with the rest of code. The current convention is "correct", but we get an additional `convert_layout`. Moreover, this order is inherited to the SMEM allocated during SWP, which could be problematic for other ops. For example, in my case I was getting the order `[4, 0, 1, 2, 3]` in SMEM for 5D blocked scales because only the innermost axis had a contiguity 4 while the rest were 1. @ThomasRaoux @pawelszczerbuk @Jokeren @rawnhenry Co-authored-by: masahi <[email protected]> Co-authored-by: Rawn Henry <[email protected]> Co-authored-by: Masahiro Masuda <[email protected]>
1 parent 3e14134 commit 717997b

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
3535
// Return true if the Load uses block pointer.
3636
bool isLoadFromTensorPtr(triton::LoadOp op);
3737

38-
// Return an array of indices enumerating the elements of 'arr' in descending
39-
// order (so that result[i] is the index of the i-th largest element of 'arr')
40-
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr);
38+
// Gets the order of a tensor from its contiguity. Places the dimensions with
39+
// the largest contiguity as the inner most dimension. If the contiguity is
40+
// all ones, returns the order {dim - 1, dim - 2, ..., 0}
41+
SmallVector<unsigned, 4>
42+
getOrderFromContiguity(const SmallVector<int64_t> &contiguity);
4143

4244
// Return the operand used to access the memory in the operation
4345
Value getMemAccessPtr(Operation *op);

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
3838
});
3939

4040
auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
41-
SmallVector<unsigned> order = argSort(contiguity);
41+
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
4242
LDBG("order=[" << triton::join(order, ", ") << "]");
4343

4444
auto matchesShape = [&refTensorType](const Value &val) {
@@ -55,8 +55,8 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
5555
Value val = getMemAccessPtr(use);
5656
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
5757
continue;
58-
auto currOrder =
59-
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
58+
auto currOrder = getOrderFromContiguity(
59+
axisInfoAnalysis.getAxisInfo(val)->getContiguity());
6060
if (order == currOrder) {
6161
LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
6262
memAccessesSameOrder.insert(use);

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,11 @@ bool isLoadFromTensorPtr(triton::LoadOp op) {
9292
return mlir::triton::isTensorPointerType(op.getPtr().getType());
9393
}
9494

95-
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr) {
95+
SmallVector<unsigned, 4>
96+
getOrderFromContiguity(const SmallVector<int64_t> &arr) {
9697
SmallVector<unsigned, 4> ret(arr.size());
9798
std::iota(ret.begin(), ret.end(), 0);
99+
std::reverse(ret.begin(), ret.end());
98100
std::stable_sort(ret.begin(), ret.end(),
99101
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
100102
return ret;

test/TritonGPU/coalesce.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,17 @@ tt.func @coalesce_poison(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1
199199
}
200200

201201
}
202+
203+
// -----
204+
205+
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
206+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
207+
tt.func public @load_3D_contig_1(%arg: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
208+
%50 = tt.splat %arg : !tt.ptr<i8> -> tensor<32x4x4x!tt.ptr<i8>, #blocked>
209+
// This checks that the pass picks the row-major ordering by default for elements with contiguity 1.
210+
// CHECK: #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
211+
// CHECK: tt.load %1 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
212+
%108 = tt.load %50 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
213+
tt.return
214+
}
215+
}

0 commit comments

Comments
 (0)