@@ -158,39 +158,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
158158 SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
159159 assert (matrix.size () == nCol);
160160
161- // We iterate the matrix following the diagonals
162- // The idea here is that we want to generate code of the form:
163- // \xor_i (x & mask_i) << s_i
164- // where s_i may by positive or negative (left or right shift)
165- // The hope here (and we see it in codegen) is that LLVM can turn
166- // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
167- // Get the i-th diagonal
168- auto getMask = [&](int i) {
161+ // Row-wise popcount to detect rows that appear exactly once across columns.
162+ uint32_t rowsUnique = 0 ;
163+ {
164+ SmallVector<int > rowPopCnt (nRow, 0 );
165+ for (int c = 0 ; c < nCol; ++c) {
166+ uint32_t colBits = matrix[c];
167+ for (int r = 0 ; r < nRow; ++r) {
168+ if (colBits & (1u << r))
169+ ++rowPopCnt[r];
170+ }
171+ }
172+ for (int r = 0 ; r < nRow; ++r) {
173+ if (rowPopCnt[r] == 1 )
174+ rowsUnique |= 1u << r;
175+ }
176+ }
177+
178+ // We iterate the matrix following the diagonals and build
179+ // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
180+ // then XOR everything else. This tends to encourage mad.lo codegen.
181+ auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t , bool > {
169182 uint32_t mask = 0 ;
170183 int row = i < 0 ? -i : 0 ;
171184 int col = i < 0 ? 0 : i;
185+ bool allRowsUnique = true ;
172186 while (row < nRow && col < nCol) {
173187 uint32_t bitValue = (matrix[col] >> row) & 1u ;
174188 mask |= bitValue << col;
189+ allRowsUnique &= ((rowsUnique >> row) & 1u ) == 1u ;
175190 ++row;
176191 ++col;
177192 }
178- return mask;
193+ return { mask, allRowsUnique} ;
179194 };
180195
181196 uint32_t explicitCols = 0 ;
182197
183198 {
184199 SmallVector<uint32_t > masks;
185200 for (int i = -nRow + 1 ; i < nCol; i++) {
186- masks.push_back (getMask (i ));
201+ masks.push_back (std::get< 0 >( getMaskAndAllRowsUnique (i) ));
187202 }
188203 bool reachedFixedPoint = false ;
189204 while (!reachedFixedPoint) {
190205 reachedFixedPoint = true ;
191206 for (uint32_t m : masks) {
192207 uint32_t c = m & ~explicitCols;
193- if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
208+ if (llvm::isPowerOf2_32 (c )) {
194209 // found a single-element diagonal
195210 explicitCols |= c;
196211 reachedFixedPoint = false ;
@@ -200,14 +215,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
200215 }
201216
202217 // handle any diagonals that have survived
203- Value ret = b.i32_val (0 );
218+ SmallVector<Value> ors;
219+ SmallVector<Value> xors;
204220 for (int i = -nRow + 1 ; i < nCol; i++) {
205- auto mask = getMask (i) & ~explicitCols;
221+ auto [mask, allRowsUnique] = getMaskAndAllRowsUnique (i);
222+ mask &= ~explicitCols;
206223 if (mask == 0 )
207224 continue ;
208225 auto masked = b.and_ (x, b.i32_val (mask));
209- ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
210- : Value (b.shl (masked, b.i32_val (-i))));
226+ auto shifted = i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
227+ : Value (b.shl (masked, b.i32_val (-i)));
228+ if (allRowsUnique) {
229+ ors.push_back (shifted);
230+ } else {
231+ xors.push_back (shifted);
232+ }
211233 }
212234
213235 // handle any explicit columns:
@@ -219,10 +241,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
219241 int32_t basis = matrix[i];
220242 if (basis == 0 )
221243 continue ;
222- ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
244+ auto select = b.select (bit_is_zero, zero, b.i32_val (basis));
245+ if ((rowsUnique & basis) == basis) {
246+ ors.push_back (select);
247+ } else {
248+ xors.push_back (select);
249+ }
223250 }
224251 }
225- return ret;
252+
253+ auto treeReduce = [&](SmallVector<Value> &terms,
254+ std::function<Value (Value, Value)> op) -> Value {
255+ if (terms.empty ())
256+ return b.i32_val (0 );
257+ while (terms.size () > 1 ) {
258+ SmallVector<Value> next;
259+ for (size_t i = 0 ; i + 1 < terms.size (); i += 2 )
260+ next.push_back (op (terms[i], terms[i + 1 ]));
261+ if (terms.size () % 2 == 1 )
262+ next.push_back (terms.back ());
263+ terms = std::move (next);
264+ }
265+ return terms[0 ];
266+ };
267+
268+ auto orPart = treeReduce (
269+ ors, [&b](Value x, Value y) { return b.or_ (x, y, /* disjoint=*/ true ); });
270+ auto xorPart =
271+ treeReduce (xors, [&b](Value x, Value y) { return b.xor_ (x, y); });
272+ return b.or_ (orPart, xorPart, /* disjoint=*/ true );
226273}
227274
228275} // namespace triton::gpu
0 commit comments