@@ -159,39 +159,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
159
159
SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
160
160
assert (matrix.size () == nCol);
161
161
162
- // We iterate the matrix following the diagonals
163
- // The idea here is that we want to generate code of the form:
164
- // \xor_i (x & mask_i) << s_i
165
- // where s_i may by positive or negative (left or right shift)
166
- // The hope here (and we see it in codegen) is that LLVM can turn
167
- // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
168
- // Get the i-th diagonal
169
- auto getMask = [&](int i) {
162
+ // Row-wise popcount to detect rows that appear exactly once across columns.
163
+ uint32_t rowsUnique = 0 ;
164
+ {
165
+ SmallVector<int > rowPopCnt (nRow, 0 );
166
+ for (int c = 0 ; c < nCol; ++c) {
167
+ uint32_t colBits = matrix[c];
168
+ for (int r = 0 ; r < nRow; ++r) {
169
+ if (colBits & (1u << r))
170
+ ++rowPopCnt[r];
171
+ }
172
+ }
173
+ for (int r = 0 ; r < nRow; ++r) {
174
+ if (rowPopCnt[r] == 1 )
175
+ rowsUnique |= 1u << r;
176
+ }
177
+ }
178
+
179
+ // We iterate the matrix following the diagonals and build
180
+ // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
181
+ // then XOR everything else. This tends to encourage mad.lo codegen.
182
+ auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t , bool > {
170
183
uint32_t mask = 0 ;
171
184
int row = i < 0 ? -i : 0 ;
172
185
int col = i < 0 ? 0 : i;
186
+ bool allRowsUnique = true ;
173
187
while (row < nRow && col < nCol) {
174
188
uint32_t bitValue = (matrix[col] >> row) & 1u ;
175
189
mask |= bitValue << col;
190
+ allRowsUnique &= ((rowsUnique >> row) & 1u ) == 1u ;
176
191
++row;
177
192
++col;
178
193
}
179
- return mask;
194
+ return { mask, allRowsUnique} ;
180
195
};
181
196
182
197
uint32_t explicitCols = 0 ;
183
198
184
199
{
185
200
SmallVector<uint32_t > masks;
186
201
for (int i = -nRow + 1 ; i < nCol; i++) {
187
- masks.push_back (getMask (i ));
202
+ masks.push_back (std::get< 0 >( getMaskAndAllRowsUnique (i) ));
188
203
}
189
204
bool reachedFixedPoint = false ;
190
205
while (!reachedFixedPoint) {
191
206
reachedFixedPoint = true ;
192
207
for (uint32_t m : masks) {
193
208
uint32_t c = m & ~explicitCols;
194
- if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
209
+ if (llvm::isPowerOf2_32 (c )) {
195
210
// found a single-element diagonal
196
211
explicitCols |= c;
197
212
reachedFixedPoint = false ;
@@ -201,14 +216,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
201
216
}
202
217
203
218
// handle any diagonals that have survived
204
- Value ret = b.i32_val (0 );
219
+ SmallVector<Value> ors;
220
+ SmallVector<Value> xors;
205
221
for (int i = -nRow + 1 ; i < nCol; i++) {
206
- auto mask = getMask (i) & ~explicitCols;
222
+ auto [mask, allRowsUnique] = getMaskAndAllRowsUnique (i);
223
+ mask &= ~explicitCols;
207
224
if (mask == 0 )
208
225
continue ;
209
226
auto masked = b.and_ (x, b.i32_val (mask));
210
- ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
211
- : Value (b.shl (masked, b.i32_val (-i))));
227
+ auto shifted = i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
228
+ : Value (b.shl (masked, b.i32_val (-i)));
229
+ if (allRowsUnique) {
230
+ ors.push_back (shifted);
231
+ } else {
232
+ xors.push_back (shifted);
233
+ }
212
234
}
213
235
214
236
// handle any explicit columns:
@@ -220,10 +242,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
220
242
int32_t basis = matrix[i];
221
243
if (basis == 0 )
222
244
continue ;
223
- ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
245
+ auto select = b.select (bit_is_zero, zero, b.i32_val (basis));
246
+ if ((rowsUnique & basis) == basis) {
247
+ ors.push_back (select);
248
+ } else {
249
+ xors.push_back (select);
250
+ }
224
251
}
225
252
}
226
- return ret;
253
+
254
+ auto treeReduce = [&](SmallVector<Value> &terms,
255
+ std::function<Value (Value, Value)> op) -> Value {
256
+ if (terms.empty ())
257
+ return b.i32_val (0 );
258
+ while (terms.size () > 1 ) {
259
+ SmallVector<Value> next;
260
+ for (size_t i = 0 ; i + 1 < terms.size (); i += 2 )
261
+ next.push_back (op (terms[i], terms[i + 1 ]));
262
+ if (terms.size () % 2 == 1 )
263
+ next.push_back (terms.back ());
264
+ terms = std::move (next);
265
+ }
266
+ return terms[0 ];
267
+ };
268
+
269
+ auto orPart = treeReduce (
270
+ ors, [&b](Value x, Value y) { return b.or_ (x, y, /* disjoint=*/ true ); });
271
+ auto xorPart =
272
+ treeReduce (xors, [&b](Value x, Value y) { return b.xor_ (x, y); });
273
+ return b.or_ (orPart, xorPart, /* disjoint=*/ true );
227
274
}
228
275
229
276
} // namespace triton::gpu
0 commit comments