@@ -159,39 +159,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
159159 SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
160160 assert (matrix.size () == nCol);
161161
162- // We iterate the matrix following the diagonals
163- // The idea here is that we want to generate code of the form:
164- // \xor_i (x & mask_i) << s_i
165- // where s_i may be positive or negative (left or right shift)
166- // The hope here (and we see it in codegen) is that LLVM can turn
167- // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
168- // Get the i-th diagonal
169- auto getMask = [&](int i) {
162+ // Row-wise popcount to detect rows that appear exactly once across columns.
163+ uint32_t rowsUnique = 0 ;
164+ {
165+ SmallVector<int > rowPopCnt (nRow, 0 );
166+ for (int c = 0 ; c < nCol; ++c) {
167+ uint32_t colBits = matrix[c];
168+ for (int r = 0 ; r < nRow; ++r) {
169+ if (colBits & (1u << r))
170+ ++rowPopCnt[r];
171+ }
172+ }
173+ for (int r = 0 ; r < nRow; ++r) {
174+ if (rowPopCnt[r] == 1 )
175+ rowsUnique |= 1u << r;
176+ }
177+ }
178+
179+ // We iterate the matrix following the diagonals and build
180+ // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
181+ // then XOR everything else. This tends to encourage mad.lo codegen.
182+ auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t , bool > {
170183 uint32_t mask = 0 ;
171184 int row = i < 0 ? -i : 0 ;
172185 int col = i < 0 ? 0 : i;
186+ bool allRowsUnique = true ;
173187 while (row < nRow && col < nCol) {
174188 uint32_t bitValue = (matrix[col] >> row) & 1u ;
175189 mask |= bitValue << col;
190+ allRowsUnique &= ((rowsUnique >> row) & 1u ) == 1u ;
176191 ++row;
177192 ++col;
178193 }
179- return mask;
194+ return { mask, allRowsUnique} ;
180195 };
181196
182197 uint32_t explicitCols = 0 ;
183198
184199 {
185200 SmallVector<uint32_t > masks;
186201 for (int i = -nRow + 1 ; i < nCol; i++) {
187- masks.push_back (getMask (i ));
202+ masks.push_back (std::get< 0 >( getMaskAndAllRowsUnique (i) ));
188203 }
189204 bool reachedFixedPoint = false ;
190205 while (!reachedFixedPoint) {
191206 reachedFixedPoint = true ;
192207 for (uint32_t m : masks) {
193208 uint32_t c = m & ~explicitCols;
194- if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
209+ if (llvm::isPowerOf2_32 (c )) {
195210 // found a single-element diagonal
196211 explicitCols |= c;
197212 reachedFixedPoint = false ;
@@ -201,14 +216,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
201216 }
202217
203218 // handle any diagonals that have survived
204- Value ret = b.i32_val (0 );
219+ SmallVector<Value> ors;
220+ SmallVector<Value> xors;
205221 for (int i = -nRow + 1 ; i < nCol; i++) {
206- auto mask = getMask (i) & ~explicitCols;
222+ auto [mask, allRowsUnique] = getMaskAndAllRowsUnique (i);
223+ mask &= ~explicitCols;
207224 if (mask == 0 )
208225 continue ;
209226 auto masked = b.and_ (x, b.i32_val (mask));
210- ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
211- : Value (b.shl (masked, b.i32_val (-i))));
227+ auto shifted = i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
228+ : Value (b.shl (masked, b.i32_val (-i)));
229+ if (allRowsUnique) {
230+ ors.push_back (shifted);
231+ } else {
232+ xors.push_back (shifted);
233+ }
212234 }
213235
214236 // handle any explicit columns:
@@ -220,10 +242,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
220242 int32_t basis = matrix[i];
221243 if (basis == 0 )
222244 continue ;
223- ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
245+ auto select = b.select (bit_is_zero, zero, b.i32_val (basis));
246+ if ((rowsUnique & basis) == basis) {
247+ ors.push_back (select);
248+ } else {
249+ xors.push_back (select);
250+ }
224251 }
225252 }
226- return ret;
253+
254+ auto treeReduce = [&](SmallVector<Value> &terms,
255+ std::function<Value (Value, Value)> op) -> Value {
256+ if (terms.empty ())
257+ return b.i32_val (0 );
258+ while (terms.size () > 1 ) {
259+ SmallVector<Value> next;
260+ for (size_t i = 0 ; i + 1 < terms.size (); i += 2 )
261+ next.push_back (op (terms[i], terms[i + 1 ]));
262+ if (terms.size () % 2 == 1 )
263+ next.push_back (terms.back ());
264+ terms = std::move (next);
265+ }
266+ return terms[0 ];
267+ };
268+
269+ auto orPart = treeReduce (
270+ ors, [&b](Value x, Value y) { return b.or_ (x, y, /* disjoint=*/ true ); });
271+ auto xorPart =
272+ treeReduce (xors, [&b](Value x, Value y) { return b.xor_ (x, y); });
273+ return b.or_ (orPart, xorPart, /* disjoint=*/ true );
227274}
228275
229276} // namespace triton::gpu
0 commit comments