
Commit 4da227c

[LLs] Tree-reduce the xor reduction in LLs codegen (#7816)
This should allow better scheduling at the expense of slightly higher register pressure.
1 parent 66b231b commit 4da227c
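
For intuition, the change replaces a linear left-to-right fold of the terms with a balanced tree reduction, which shortens the dependency chain and gives the scheduler independent pairs to interleave, at the cost of keeping more partial results live. A minimal standalone sketch of the pattern (plain C++ over plain integers; the xorTree name is illustrative and is not the Triton builder API used in the diff below):

#include <cstdint>
#include <vector>

// Balanced tree reduction of XOR terms. The dependency chain is
// ceil(log2(n)) deep instead of n - 1 for a left-to-right fold, which
// exposes independent pairs to the scheduler at the cost of keeping
// more partial results live (the register-pressure trade-off above).
uint32_t xorTree(std::vector<uint32_t> terms) {
  if (terms.empty())
    return 0;
  while (terms.size() > 1) {
    std::vector<uint32_t> next;
    for (size_t i = 0; i + 1 < terms.size(); i += 2)
      next.push_back(terms[i] ^ terms[i + 1]);
    if (terms.size() % 2 == 1)
      next.push_back(terms.back());
    terms = std::move(next);
  }
  return terms[0];
}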

2 files changed: +67 -18 lines changed


lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 64 additions & 17 deletions
@@ -158,39 +158,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
   SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
   assert(matrix.size() == nCol);

-  // We iterate the matrix following the diagonals
-  // The idea here is that we want to generate code of the form:
-  // \xor_i (x & mask_i) << s_i
-  // where s_i may by positive or negative (left or right shift)
-  // The hope here (and we see it in codegen) is that LLVM can turn
-  // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
-  // Get the i-th diagonal
-  auto getMask = [&](int i) {
+  // Row-wise popcount to detect rows that appear exactly once across columns.
+  uint32_t rowsUnique = 0;
+  {
+    SmallVector<int> rowPopCnt(nRow, 0);
+    for (int c = 0; c < nCol; ++c) {
+      uint32_t colBits = matrix[c];
+      for (int r = 0; r < nRow; ++r) {
+        if (colBits & (1u << r))
+          ++rowPopCnt[r];
+      }
+    }
+    for (int r = 0; r < nRow; ++r) {
+      if (rowPopCnt[r] == 1)
+        rowsUnique |= 1u << r;
+    }
+  }
+
+  // We iterate the matrix following the diagonals and build
+  // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
+  // then XOR everything else. This tends to encourage mad.lo codegen.
+  auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t, bool> {
     uint32_t mask = 0;
     int row = i < 0 ? -i : 0;
     int col = i < 0 ? 0 : i;
+    bool allRowsUnique = true;
     while (row < nRow && col < nCol) {
       uint32_t bitValue = (matrix[col] >> row) & 1u;
       mask |= bitValue << col;
+      allRowsUnique &= ((rowsUnique >> row) & 1u) == 1u;
       ++row;
       ++col;
     }
-    return mask;
+    return {mask, allRowsUnique};
   };

   uint32_t explicitCols = 0;

   {
     SmallVector<uint32_t> masks;
     for (int i = -nRow + 1; i < nCol; i++) {
-      masks.push_back(getMask(i));
+      masks.push_back(std::get<0>(getMaskAndAllRowsUnique(i)));
     }
     bool reachedFixedPoint = false;
     while (!reachedFixedPoint) {
       reachedFixedPoint = true;
       for (uint32_t m : masks) {
         uint32_t c = m & ~explicitCols;
-        if ((c != 0) && ((c & (c - 1)) == 0)) {
+        if (llvm::isPowerOf2_32(c)) {
           // found a single-element diagonal
           explicitCols |= c;
           reachedFixedPoint = false;
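
An aside on the row-uniqueness test above: if an output row is set in exactly one column of the basis matrix, only one diagonal term can ever produce a 1 at that bit, so XOR and OR coincide there. A small self-contained check of the popcount logic (plain C++; the 3x3 matrix is a made-up example, not taken from the patch):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical 3x3 GF(2) matrix, one column per entry; bit r of
  // matrix[c] is A[r][c]. Rows 0 and 2 are set in exactly one column.
  std::vector<uint32_t> matrix = {0b001, 0b010, 0b110};
  int nRow = 3, nCol = 3;

  uint32_t rowsUnique = 0;
  std::vector<int> rowPopCnt(nRow, 0);
  for (int c = 0; c < nCol; ++c)
    for (int r = 0; r < nRow; ++r)
      if (matrix[c] & (1u << r))
        ++rowPopCnt[r];
  for (int r = 0; r < nRow; ++r)
    if (rowPopCnt[r] == 1)
      rowsUnique |= 1u << r;

  // Rows 0 and 2 are unique: a term writing those bits cannot collide
  // with any other term, so it can be combined with a disjoint OR.
  assert(rowsUnique == 0b101u);
  return 0;
}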
@@ -200,14 +215,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
   }

   // handle any diagonals that have survived
-  Value ret = b.i32_val(0);
+  SmallVector<Value> ors;
+  SmallVector<Value> xors;
   for (int i = -nRow + 1; i < nCol; i++) {
-    auto mask = getMask(i) & ~explicitCols;
+    auto [mask, allRowsUnique] = getMaskAndAllRowsUnique(i);
+    mask &= ~explicitCols;
     if (mask == 0)
       continue;
     auto masked = b.and_(x, b.i32_val(mask));
-    ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
-                             : Value(b.shl(masked, b.i32_val(-i))));
+    auto shifted = i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
+                          : Value(b.shl(masked, b.i32_val(-i)));
+    if (allRowsUnique) {
+      ors.push_back(shifted);
+    } else {
+      xors.push_back(shifted);
+    }
   }

   // handle any explicit columns:
@@ -219,10 +241,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
       int32_t basis = matrix[i];
       if (basis == 0)
         continue;
-      ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
+      auto select = b.select(bit_is_zero, zero, b.i32_val(basis));
+      if ((rowsUnique & basis) == basis) {
+        ors.push_back(select);
+      } else {
+        xors.push_back(select);
+      }
     }
   }
-  return ret;
+
+  auto treeReduce = [&](SmallVector<Value> &terms,
+                        std::function<Value(Value, Value)> op) -> Value {
+    if (terms.empty())
+      return b.i32_val(0);
+    while (terms.size() > 1) {
+      SmallVector<Value> next;
+      for (size_t i = 0; i + 1 < terms.size(); i += 2)
+        next.push_back(op(terms[i], terms[i + 1]));
+      if (terms.size() % 2 == 1)
+        next.push_back(terms.back());
+      terms = std::move(next);
+    }
+    return terms[0];
+  };
+
+  auto orPart = treeReduce(
+      ors, [&b](Value x, Value y) { return b.or_(x, y, /*disjoint=*/true); });
+  auto xorPart =
+      treeReduce(xors, [&b](Value x, Value y) { return b.xor_(x, y); });
+  return b.or_(orPart, xorPart, /*disjoint=*/true);
 }

 } // namespace triton::gpu
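
A note on the final combine: b.or_(..., /*disjoint=*/true) emits LLVM's or with the disjoint flag (visible as llvm.or disjoint in the updated test below), which promises the two operands share no set bits. Under that promise OR, XOR, and integer addition all agree, so the backend is free to rewrite the OR as an add and fold it into mad.lo-style addressing, as the original comment hoped. A tiny standalone check of that equivalence (plain C++, values chosen for illustration):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t a = 0x30, b = 0x05; // no set bits in common
  assert((a & b) == 0u);       // the "disjoint" precondition
  assert((a | b) == (a ^ b));  // OR and XOR agree on disjoint operands
  assert((a | b) == a + b);    // ...and both equal integer addition
  return 0;
}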

test/Conversion/amd/convert_layout.mlir

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ

   // Part of offset computation generated by applyLinearLayout function
   // CHECK: [[SEL:%.*]]= llvm.select {{.*}}, {{.*}}, [[CST_128]]
-  // CHECK: [[OFFSET_0:%.*]] = llvm.xor {{.*}}, [[SEL]]
+  // CHECK-COUNT-3: llvm.or disjoint
+  // CHECK-COUNT-2: llvm.xor
+  // CHECK: [[OFFSET_0:%.*]] = llvm.or disjoint
   // CHECK: [[OFFSET_1:%.*]] = llvm.xor {{.*}}, [[OFFSET_0]] : i32

   // Part of offset computation generated by lowerLdSt function after applyLinearLayout
