Skip to content

Commit 4fe73e1

Browse files
authored
[LAYOUTS] Use columns as well as diagonals (#7403)
This improves two kernel benchmarks on Blackwell and is completely neutral on Hopper; no regressions were observed.
1 parent 9a01fe9 commit 4fe73e1

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
113113
auto nRow = A.getTotalOutDimSizeLog2();
114114
SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
115115
assert(matrix.size() == nCol);
116+
116117
// We iterate the matrix following the diagonals
117118
// The idea here is that we want to generate code of the form:
118119
// \xor_i (x & mask_i) << s_i
@@ -133,15 +134,50 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
133134
return mask;
134135
};
135136

137+
uint32_t explicitCols = 0;
138+
139+
{
140+
SmallVector<uint32_t> masks;
141+
for (int i = -nRow + 1; i < nCol; i++) {
142+
masks.push_back(getMask(i));
143+
}
144+
bool reachedFixedPoint = false;
145+
while (!reachedFixedPoint) {
146+
reachedFixedPoint = true;
147+
for (uint32_t m : masks) {
148+
uint32_t c = m & ~explicitCols;
149+
if ((c != 0) && ((c & (c - 1)) == 0)) {
150+
// Exactly one not-yet-explicit column is left on this diagonal: make that column explicit.
151+
explicitCols |= c;
152+
reachedFixedPoint = false;
153+
}
154+
}
155+
}
156+
}
157+
158+
// Emit mask-and-shift code for whatever is left of each diagonal once the explicit columns are removed.
136159
Value ret = b.i32_val(0);
137160
for (int i = -nRow + 1; i < nCol; i++) {
138-
auto mask = getMask(i);
161+
auto mask = getMask(i) & ~explicitCols;
139162
if (mask == 0)
140163
continue;
141164
auto masked = b.and_(x, b.i32_val(mask));
142165
ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
143166
: Value(b.shl(masked, b.i32_val(-i))));
144167
}
168+
169+
// Handle the explicit columns: if bit i of x is set, XOR the whole basis vector matrix[i] into the result.
170+
Value zero = b.i32_val(0);
171+
for (int i = 0; i < nCol; i++) {
172+
if ((explicitCols >> i) & 1) {
173+
Value bit = b.and_(x, b.i32_val(1 << i));
174+
Value bit_is_zero = b.icmp_eq(bit, zero);
175+
int32_t basis = matrix[i];
176+
if (basis == 0)
177+
continue;
178+
ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
179+
}
180+
}
145181
return ret;
146182
}
147183

0 commit comments

Comments
 (0)