Skip to content

Commit 99796f4

Browse files
Revert "[LLs] Tree-reduce the xor reduction in LLs codegen (#7816)"
This reverts commit 4da227c.
1 parent 2674dbb commit 99796f4

File tree

2 files changed

+18
-67
lines changed

2 files changed

+18
-67
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 17 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -158,54 +158,39 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
158158
SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
159159
assert(matrix.size() == nCol);
160160

161-
// Row-wise popcount to detect rows that appear exactly once across columns.
162-
uint32_t rowsUnique = 0;
163-
{
164-
SmallVector<int> rowPopCnt(nRow, 0);
165-
for (int c = 0; c < nCol; ++c) {
166-
uint32_t colBits = matrix[c];
167-
for (int r = 0; r < nRow; ++r) {
168-
if (colBits & (1u << r))
169-
++rowPopCnt[r];
170-
}
171-
}
172-
for (int r = 0; r < nRow; ++r) {
173-
if (rowPopCnt[r] == 1)
174-
rowsUnique |= 1u << r;
175-
}
176-
}
177-
178-
// We iterate the matrix following the diagonals and build
179-
// (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
180-
// then XOR everything else. This tends to encourage mad.lo codegen.
181-
auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t, bool> {
161+
// We iterate the matrix following the diagonals
162+
// The idea here is that we want to generate code of the form:
163+
// \xor_i (x & mask_i) << s_i
164+
// where s_i may by positive or negative (left or right shift)
165+
// The hope here (and we see it in codegen) is that LLVM can turn
166+
// the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
167+
// Get the i-th diagonal
168+
auto getMask = [&](int i) {
182169
uint32_t mask = 0;
183170
int row = i < 0 ? -i : 0;
184171
int col = i < 0 ? 0 : i;
185-
bool allRowsUnique = true;
186172
while (row < nRow && col < nCol) {
187173
uint32_t bitValue = (matrix[col] >> row) & 1u;
188174
mask |= bitValue << col;
189-
allRowsUnique &= ((rowsUnique >> row) & 1u) == 1u;
190175
++row;
191176
++col;
192177
}
193-
return {mask, allRowsUnique};
178+
return mask;
194179
};
195180

196181
uint32_t explicitCols = 0;
197182

198183
{
199184
SmallVector<uint32_t> masks;
200185
for (int i = -nRow + 1; i < nCol; i++) {
201-
masks.push_back(std::get<0>(getMaskAndAllRowsUnique(i)));
186+
masks.push_back(getMask(i));
202187
}
203188
bool reachedFixedPoint = false;
204189
while (!reachedFixedPoint) {
205190
reachedFixedPoint = true;
206191
for (uint32_t m : masks) {
207192
uint32_t c = m & ~explicitCols;
208-
if (llvm::isPowerOf2_32(c)) {
193+
if ((c != 0) && ((c & (c - 1)) == 0)) {
209194
// found a single-element diagonal
210195
explicitCols |= c;
211196
reachedFixedPoint = false;
@@ -215,21 +200,14 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
215200
}
216201

217202
// handle any diagonals that have survived
218-
SmallVector<Value> ors;
219-
SmallVector<Value> xors;
203+
Value ret = b.i32_val(0);
220204
for (int i = -nRow + 1; i < nCol; i++) {
221-
auto [mask, allRowsUnique] = getMaskAndAllRowsUnique(i);
222-
mask &= ~explicitCols;
205+
auto mask = getMask(i) & ~explicitCols;
223206
if (mask == 0)
224207
continue;
225208
auto masked = b.and_(x, b.i32_val(mask));
226-
auto shifted = i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
227-
: Value(b.shl(masked, b.i32_val(-i)));
228-
if (allRowsUnique) {
229-
ors.push_back(shifted);
230-
} else {
231-
xors.push_back(shifted);
232-
}
209+
ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
210+
: Value(b.shl(masked, b.i32_val(-i))));
233211
}
234212

235213
// handle any explicit columns:
@@ -241,35 +219,10 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
241219
int32_t basis = matrix[i];
242220
if (basis == 0)
243221
continue;
244-
auto select = b.select(bit_is_zero, zero, b.i32_val(basis));
245-
if ((rowsUnique & basis) == basis) {
246-
ors.push_back(select);
247-
} else {
248-
xors.push_back(select);
249-
}
222+
ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
250223
}
251224
}
252-
253-
auto treeReduce = [&](SmallVector<Value> &terms,
254-
std::function<Value(Value, Value)> op) -> Value {
255-
if (terms.empty())
256-
return b.i32_val(0);
257-
while (terms.size() > 1) {
258-
SmallVector<Value> next;
259-
for (size_t i = 0; i + 1 < terms.size(); i += 2)
260-
next.push_back(op(terms[i], terms[i + 1]));
261-
if (terms.size() % 2 == 1)
262-
next.push_back(terms.back());
263-
terms = std::move(next);
264-
}
265-
return terms[0];
266-
};
267-
268-
auto orPart = treeReduce(
269-
ors, [&b](Value x, Value y) { return b.or_(x, y, /*disjoint=*/true); });
270-
auto xorPart =
271-
treeReduce(xors, [&b](Value x, Value y) { return b.xor_(x, y); });
272-
return b.or_(orPart, xorPart, /*disjoint=*/true);
225+
return ret;
273226
}
274227

275228
} // namespace triton::gpu

test/Conversion/amd/convert_layout.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
1212

1313
// Part of offset computation generated by applyLinearLayout function
1414
// CHECK: [[SEL:%.*]]= llvm.select {{.*}}, {{.*}}, [[CST_128]]
15-
// CHECK-COUNT-3: llvm.or disjoint
16-
// CHECK-COUNT-2: llvm.xor
17-
// CHECK: [[OFFSET_0:%.*]] = llvm.or disjoint
15+
// CHECK: [[OFFSET_0:%.*]] = llvm.xor {{.*}}, [[SEL]]
1816
// CHECK: [[OFFSET_1:%.*]] = llvm.xor {{.*}}, [[OFFSET_0]] : i32
1917

2018
// Part of offset computation generated by lowerLdSt function after applyLinearLayout

0 commit comments

Comments
 (0)