Skip to content

Commit 4ee3353

Browse files
Revert "Revert "[LLs] Tree-reduce the xor reduction in LLs codegen (#7816)"" (#4980)
This reverts commit 99796f4.
2 parents d607951 + 03259b0 commit 4ee3353

File tree

6 files changed

+700
-586
lines changed

6 files changed

+700
-586
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -159,39 +159,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
159159
SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
160160
assert(matrix.size() == nCol);
161161

162-
// We iterate the matrix following the diagonals
163-
// The idea here is that we want to generate code of the form:
164-
// \xor_i (x & mask_i) << s_i
165-
// where s_i may by positive or negative (left or right shift)
166-
// The hope here (and we see it in codegen) is that LLVM can turn
167-
// the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
168-
// Get the i-th diagonal
169-
auto getMask = [&](int i) {
162+
// Row-wise popcount to detect rows that appear exactly once across columns.
163+
uint32_t rowsUnique = 0;
164+
{
165+
SmallVector<int> rowPopCnt(nRow, 0);
166+
for (int c = 0; c < nCol; ++c) {
167+
uint32_t colBits = matrix[c];
168+
for (int r = 0; r < nRow; ++r) {
169+
if (colBits & (1u << r))
170+
++rowPopCnt[r];
171+
}
172+
}
173+
for (int r = 0; r < nRow; ++r) {
174+
if (rowPopCnt[r] == 1)
175+
rowsUnique |= 1u << r;
176+
}
177+
}
178+
179+
// We iterate the matrix following the diagonals and build
180+
// (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
181+
// then XOR everything else. This tends to encourage mad.lo codegen.
182+
auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t, bool> {
170183
uint32_t mask = 0;
171184
int row = i < 0 ? -i : 0;
172185
int col = i < 0 ? 0 : i;
186+
bool allRowsUnique = true;
173187
while (row < nRow && col < nCol) {
174188
uint32_t bitValue = (matrix[col] >> row) & 1u;
175189
mask |= bitValue << col;
190+
allRowsUnique &= ((rowsUnique >> row) & 1u) == 1u;
176191
++row;
177192
++col;
178193
}
179-
return mask;
194+
return {mask, allRowsUnique};
180195
};
181196

182197
uint32_t explicitCols = 0;
183198

184199
{
185200
SmallVector<uint32_t> masks;
186201
for (int i = -nRow + 1; i < nCol; i++) {
187-
masks.push_back(getMask(i));
202+
masks.push_back(std::get<0>(getMaskAndAllRowsUnique(i)));
188203
}
189204
bool reachedFixedPoint = false;
190205
while (!reachedFixedPoint) {
191206
reachedFixedPoint = true;
192207
for (uint32_t m : masks) {
193208
uint32_t c = m & ~explicitCols;
194-
if ((c != 0) && ((c & (c - 1)) == 0)) {
209+
if (llvm::isPowerOf2_32(c)) {
195210
// found a single-element diagonal
196211
explicitCols |= c;
197212
reachedFixedPoint = false;
@@ -201,14 +216,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
201216
}
202217

203218
// handle any diagonals that have survived
204-
Value ret = b.i32_val(0);
219+
SmallVector<Value> ors;
220+
SmallVector<Value> xors;
205221
for (int i = -nRow + 1; i < nCol; i++) {
206-
auto mask = getMask(i) & ~explicitCols;
222+
auto [mask, allRowsUnique] = getMaskAndAllRowsUnique(i);
223+
mask &= ~explicitCols;
207224
if (mask == 0)
208225
continue;
209226
auto masked = b.and_(x, b.i32_val(mask));
210-
ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
211-
: Value(b.shl(masked, b.i32_val(-i))));
227+
auto shifted = i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
228+
: Value(b.shl(masked, b.i32_val(-i)));
229+
if (allRowsUnique) {
230+
ors.push_back(shifted);
231+
} else {
232+
xors.push_back(shifted);
233+
}
212234
}
213235

214236
// handle any explicit columns:
@@ -220,10 +242,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
220242
int32_t basis = matrix[i];
221243
if (basis == 0)
222244
continue;
223-
ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
245+
auto select = b.select(bit_is_zero, zero, b.i32_val(basis));
246+
if ((rowsUnique & basis) == basis) {
247+
ors.push_back(select);
248+
} else {
249+
xors.push_back(select);
250+
}
224251
}
225252
}
226-
return ret;
253+
254+
auto treeReduce = [&](SmallVector<Value> &terms,
255+
std::function<Value(Value, Value)> op) -> Value {
256+
if (terms.empty())
257+
return b.i32_val(0);
258+
while (terms.size() > 1) {
259+
SmallVector<Value> next;
260+
for (size_t i = 0; i + 1 < terms.size(); i += 2)
261+
next.push_back(op(terms[i], terms[i + 1]));
262+
if (terms.size() % 2 == 1)
263+
next.push_back(terms.back());
264+
terms = std::move(next);
265+
}
266+
return terms[0];
267+
};
268+
269+
auto orPart = treeReduce(
270+
ors, [&b](Value x, Value y) { return b.or_(x, y, /*disjoint=*/true); });
271+
auto xorPart =
272+
treeReduce(xors, [&b](Value x, Value y) { return b.xor_(x, y); });
273+
return b.or_(orPart, xorPart, /*disjoint=*/true);
227274
}
228275

229276
} // namespace triton::gpu

test/Conversion/amd/convert_layout.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
1212

1313
// Part of offset computation generated by applyLinearLayout function
1414
// CHECK: [[SEL:%.*]]= llvm.select {{.*}}, {{.*}}, [[CST_128]]
15-
// CHECK: [[OFFSET_0:%.*]] = llvm.xor {{.*}}, [[SEL]]
15+
// CHECK-COUNT-3: llvm.or disjoint
16+
// CHECK-COUNT-2: llvm.xor
17+
// CHECK: [[OFFSET_0:%.*]] = llvm.or disjoint
1618
// CHECK: [[OFFSET_1:%.*]] = llvm.xor {{.*}}, [[OFFSET_0]] : i32
1719

1820
// Part of offset computation generated by lowerLdSt function after applyLinearLayout

0 commit comments

Comments
 (0)