Commit c172d53
[BACKEND] Generalise maybeDeduplicate to all layouts (#8492)
We had a subtle asymmetry here that was producing different PTX for the same layout. We now generalise this pass to work with any layout and drop a few of the restrictions the previous implementation had.
1 parent 8f5aa60

2 files changed: +30, -93 lines

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 28 additions & 86 deletions
@@ -57,6 +57,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
   // computation is eliminated.
   SmallVector<Value> maybeDeduplicate(SourceOp op,
                                       SmallVector<Value> resultVals) const {
+    auto ctx = op.getContext();
     if (!isMemoryEffectFree(op))
       // the op has side effects: can't dedup
       return resultVals;
@@ -65,104 +66,45 @@
       // there must be exactly 1 result
       return resultVals;
     Value result = results[0];
-    Type type = result.getType();
-    if (!type)
-      return resultVals;
-    RankedTensorType rtType = dyn_cast<RankedTensorType>(type);
+    RankedTensorType rtType = dyn_cast<RankedTensorType>(result.getType());
     if (!rtType)
       // the result must be a tensor
       return resultVals;
-    Attribute encoding = rtType.getEncoding();
-    if (!encoding)
-      // encoding not available
-      return resultVals;
-    Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
-        isa<AMDWmmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
-      // now. We saw mismatches for some flash-attention and dot tests on AMD
-      // backend. Note that this logic works for sliced layout whose parent is
-      // mfma layout. Therefore, this is not combined with the following check.
-      return resultVals;
-    while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
-      baseEncoding = sliced.getParent();
-    if (isa<LinearEncodingAttr, DotOperandEncodingAttr>(baseEncoding)) {
-      // TODO: this logic seems incorrect for mma layout. Skip for now.
-      // The following test crashes and some other miscompile:
-      // test_core::test_fp8_dot_acc
-      return resultVals;
-    }
 
-    SmallVector<unsigned> elemsPerThread = getElemsPerThread(rtType);
-    int rank = elemsPerThread.size();
-    if (product<unsigned>(elemsPerThread) != resultVals.size())
-      return resultVals;
+    // Bail out if we don't have the constancy analysis
     AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result);
     if (!axisInfo)
-      // axis info (e.g., constancy) not available
-      return resultVals;
-    SmallVector<unsigned> contigPerThread = getContigPerThread(rtType);
-    if (rank != contigPerThread.size())
       return resultVals;
-
     SmallVector<int64_t> constancy = axisInfo->getConstancy();
-    if (rank != constancy.size())
-      return resultVals;
-    bool hasConstancy = false;
-    for (int i = 0; i < rank; ++i) {
-      if (constancy[i] > contigPerThread[i]) {
-        if (constancy[i] % contigPerThread[i] != 0)
-          // constancy is not evenly covered by contigPerThread
-          return resultVals;
-        // can't move the values across different
-        // "contigPerThread"-sized blocks
-        constancy[i] = contigPerThread[i];
-      }
-      if (elemsPerThread[i] < 1 || constancy[i] < 1)
-        return resultVals;
-      if (!(elemsPerThread[i] % constancy[i] == 0 ||
-            constancy[i] % elemsPerThread[i] == 0))
-        // either the constancy along each dimension must fit
-        // into the elemsPerThread or the other way around
-        return resultVals;
-      if (constancy[i] > 1)
-        hasConstancy = true;
-    }
-    if (!hasConstancy)
-      // nothing to deduplicate
-      return resultVals;
 
-    if (rank > 1) {
-      // reorder the shape and constancy vectors by the axis order:
-      // from the fastest-changing to the smallest-changing axis
-      SmallVector<unsigned> order = getOrder(rtType);
-      if (rank != order.size())
-        return resultVals;
-      elemsPerThread = applyPermutation(elemsPerThread, order);
-      constancy = applyPermutation(constancy, order);
-    }
+    if (llvm::all_of(constancy, [](int64_t c) { return c == 1; }))
+      return resultVals;
 
-    SmallVector<unsigned> strides(rank, 1);
-    for (int i = 1; i < rank; ++i) {
-      strides[i] = strides[i - 1] * elemsPerThread[i - 1];
-    }
-    SmallVector<Value> dedupResultVals;
-    dedupResultVals.reserve(resultVals.size());
-    for (int i = 0; i < resultVals.size(); ++i) {
-      // each coordinate of the orig_idx is "coarsened" using the
-      // constancy along this dimension: the resulting dedup_idx
-      // points to the reused value in the original resultsVal
-      int orig_idx = i;
-      int dedup_idx = 0;
-      for (int j = 0; j < rank; ++j) {
-        int coord_j = orig_idx % elemsPerThread[j];
-        dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j];
-        orig_idx /= elemsPerThread[j];
+    // We zero out the bases that are constant
+    auto kReg = StringAttr::get(ctx, "register");
+    auto ll = toLinearLayout(rtType);
+    auto dims = to_vector(ll.getOutDimNames());
+    auto llReg = ll.sublayout({kReg}, dims);
+    auto inv = ll.pseudoinvert();
+    auto invReg = inv.sublayout(dims, {kReg});
+    auto bases_inv = invReg.getBases();
+    for (auto [c, d] : llvm::zip(constancy, dims)) {
+      assert(llvm::isPowerOf2_32(c));
+      for (int i = 0; i < llvm::Log2_32(c); i++) {
+        bases_inv[d][i] = {0};
       }
-      dedupResultVals.push_back(resultVals[dedup_idx]);
     }
-
-    return dedupResultVals;
+    auto invBroadcast =
+        LinearLayout(bases_inv, invReg.getOutDims(), /*isSurjective=*/false);
+    auto cvt = llReg.compose(invBroadcast);
+
+    // Deduplicate the result values
+    SmallVector<Value> outVals(resultVals.size());
+    for (int i = 0; i < outVals.size(); i++) {
+      auto srcIdx = cvt.apply({{kReg, i}}).begin()->second;
+      outVals[i] = resultVals[srcIdx];
+    }
+    return outVals;
   }
   LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
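The heart of the change is the last hunk: instead of coarsening per-dimension indices by hand, the new code composes the register-to-coordinate layout with a pseudoinverse whose constant bases have been zeroed, yielding a register-to-register map that sends each register to the canonical register already holding the same value. Below is a minimal self-contained sketch of that idea in plain C++, using toy XOR-basis layouts rather than Triton's actual LinearLayout class; the apply helper and the basis vectors are illustrative assumptions, not the real API.

// Toy model of the new dedup path (assumed helper, not Triton's API).
// A layout over GF(2) is a list of basis vectors; applying it XORs
// together the bases selected by the set bits of the input index.
#include <cstdint>
#include <iostream>
#include <vector>

uint32_t apply(const std::vector<uint32_t> &bases, uint32_t in) {
  uint32_t out = 0;
  for (size_t i = 0; i < bases.size(); ++i)
    if (in & (1u << i))
      out ^= bases[i];
  return out;
}

int main() {
  // 8 registers map to 8 elements of a 1-D tensor; both the layout and
  // its pseudoinverse are the identity here (basis i is 1 << i).
  std::vector<uint32_t> llReg = {1, 2, 4};  // register -> element
  std::vector<uint32_t> invReg = {1, 2, 4}; // element -> register

  // Constancy 4: the low log2(4) = 2 element bits don't affect the value,
  // so zero the corresponding inverse bases (cf. bases_inv[d][i] = {0}).
  for (int i = 0; i < 2; ++i)
    invReg[i] = 0;

  // cvt = llReg composed with the broadcast inverse sends each register
  // to the canonical register of its constancy block.
  for (uint32_t r = 0; r < 8; ++r)
    std::cout << r << " -> " << apply(invReg, apply(llReg, r)) << "\n";
  // Prints 0..3 -> 0 and 4..7 -> 4: one live value per block of four.
}

Because the map is linear over GF(2), zeroing a basis collapses an entire power-of-two block onto its representative, which is presumably why the new code asserts that each constancy value is a power of two.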

test/Conversion/amd/dedup-by-constancy.mlir

Lines changed: 2 additions & 7 deletions
@@ -1,18 +1,13 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s
 
 // CHECK-LABEL: dedup_by_constancy_mfma
-// CHECK-COUNT-4: llvm.icmp "slt"
+// CHECK-COUNT-2: llvm.icmp "slt"
 // CHECK-NOT: llvm.icmp "slt"
-// Here is why we expect exactly 4 icmp:
 // For a 32x32 tensor A with mfma layout, each thread holds 16 elements, which are divided
 // into 4 groups. E.g. thread 0 holds elements A[0:3,0], A[8:11,0], A[16:19,0], and A[24:27,0].
 // In this example, constancy of the tensor is 16 for dim 0, meaning A[0:15,0] have same values
 // and A[16:31,0] have same values. Therefore, for thread 0, the first 8 elements are duplicated
-// and the last 8 elements are duplicated. Ideally, thread 0 only needs two icmp, one for the
-// first 8 elements and the other for the last 8 elements. In practice, the dedup analysis
-// only allows duplication within each group of 4 elemnets. Therefore, we expect 4 icmp, one
-// for each group of 4 elements.
-// In the future, we can reduce the icmp to 2 in such case.
+// and the last 8 elements are duplicated.
 #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 1], instrShape = [32, 32, 8], isTransposed = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @dedup_by_constancy_mfma(%arg0: i32 {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
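As a sanity check of the new expected count, the same toy model from the sketch above reproduces the two surviving icmp for this mfma case: thread 0's registers 0-15 map to rows 0-3, 8-11, 16-19, 24-27 of dim 0, and constancy 16 zeroes the low four row bases of the pseudoinverse. The concrete basis vectors below are worked out by hand from the comment in the test, not queried from Triton.

// Hand-worked check of the mfma example (assumed bases, not the real API).
#include <cstdint>
#include <iostream>
#include <vector>

// XOR together the bases selected by the set bits of `in` (toy GF(2) layout).
uint32_t apply(const std::vector<uint32_t> &bases, uint32_t in) {
  uint32_t out = 0;
  for (size_t i = 0; i < bases.size(); ++i)
    if (in & (1u << i))
      out ^= bases[i];
  return out;
}

int main() {
  // register -> row for thread 0: row = (r & 3) + 8 * bit2(r) + 16 * bit3(r)
  std::vector<uint32_t> llReg = {1, 2, 8, 16};
  // Pseudoinverse row -> register: row bits 0,1,3,4 map back to register
  // bits 0,1,2,3; row bit 2 is never produced, so it inverts to 0.
  std::vector<uint32_t> invReg = {1, 2, 0, 4, 8};
  // Constancy 16 along dim 0: zero the low log2(16) = 4 row bases.
  for (int i = 0; i < 4; ++i)
    invReg[i] = 0;

  // Count the registers that survive deduplication (those mapped to
  // themselves): registers 0-7 collapse onto 0 and 8-15 onto 8.
  int unique = 0;
  for (uint32_t r = 0; r < 16; ++r)
    if (apply(invReg, apply(llReg, r)) == r)
      ++unique;
  std::cout << unique << " unique registers\n"; // prints "2 unique registers"
}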
