@@ -158,54 +158,39 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
158158 SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
159159 assert (matrix.size () == nCol);
160160
161- // Row-wise popcount to detect rows that appear exactly once across columns.
162- uint32_t rowsUnique = 0 ;
163- {
164- SmallVector<int > rowPopCnt (nRow, 0 );
165- for (int c = 0 ; c < nCol; ++c) {
166- uint32_t colBits = matrix[c];
167- for (int r = 0 ; r < nRow; ++r) {
168- if (colBits & (1u << r))
169- ++rowPopCnt[r];
170- }
171- }
172- for (int r = 0 ; r < nRow; ++r) {
173- if (rowPopCnt[r] == 1 )
174- rowsUnique |= 1u << r;
175- }
176- }
177-
178- // We iterate the matrix following the diagonals and build
179- // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
180- // then XOR everything else. This tends to encourage mad.lo codegen.
181- auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t , bool > {
161+ // We iterate the matrix following the diagonals
162+ // The idea here is that we want to generate code of the form:
163+ // \xor_i (x & mask_i) << s_i
164+ // where s_i may by positive or negative (left or right shift)
165+ // The hope here (and we see it in codegen) is that LLVM can turn
166+ // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
167+ // Get the i-th diagonal
168+ auto getMask = [&](int i) {
182169 uint32_t mask = 0 ;
183170 int row = i < 0 ? -i : 0 ;
184171 int col = i < 0 ? 0 : i;
185- bool allRowsUnique = true ;
186172 while (row < nRow && col < nCol) {
187173 uint32_t bitValue = (matrix[col] >> row) & 1u ;
188174 mask |= bitValue << col;
189- allRowsUnique &= ((rowsUnique >> row) & 1u ) == 1u ;
190175 ++row;
191176 ++col;
192177 }
193- return { mask, allRowsUnique} ;
178+ return mask;
194179 };
195180
196181 uint32_t explicitCols = 0 ;
197182
198183 {
199184 SmallVector<uint32_t > masks;
200185 for (int i = -nRow + 1 ; i < nCol; i++) {
201- masks.push_back (std::get< 0 >( getMaskAndAllRowsUnique (i) ));
186+ masks.push_back (getMask (i ));
202187 }
203188 bool reachedFixedPoint = false ;
204189 while (!reachedFixedPoint) {
205190 reachedFixedPoint = true ;
206191 for (uint32_t m : masks) {
207192 uint32_t c = m & ~explicitCols;
208- if (llvm::isPowerOf2_32 (c )) {
193+ if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
209194 // found a single-element diagonal
210195 explicitCols |= c;
211196 reachedFixedPoint = false ;
@@ -215,21 +200,14 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
215200 }
216201
217202 // handle any diagonals that have survived
218- SmallVector<Value> ors;
219- SmallVector<Value> xors;
203+ Value ret = b.i32_val (0 );
220204 for (int i = -nRow + 1 ; i < nCol; i++) {
221- auto [mask, allRowsUnique] = getMaskAndAllRowsUnique (i);
222- mask &= ~explicitCols;
205+ auto mask = getMask (i) & ~explicitCols;
223206 if (mask == 0 )
224207 continue ;
225208 auto masked = b.and_ (x, b.i32_val (mask));
226- auto shifted = i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
227- : Value (b.shl (masked, b.i32_val (-i)));
228- if (allRowsUnique) {
229- ors.push_back (shifted);
230- } else {
231- xors.push_back (shifted);
232- }
209+ ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
210+ : Value (b.shl (masked, b.i32_val (-i))));
233211 }
234212
235213 // handle any explicit columns:
@@ -241,35 +219,10 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
241219 int32_t basis = matrix[i];
242220 if (basis == 0 )
243221 continue ;
244- auto select = b.select (bit_is_zero, zero, b.i32_val (basis));
245- if ((rowsUnique & basis) == basis) {
246- ors.push_back (select);
247- } else {
248- xors.push_back (select);
249- }
222+ ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
250223 }
251224 }
252-
253- auto treeReduce = [&](SmallVector<Value> &terms,
254- std::function<Value (Value, Value)> op) -> Value {
255- if (terms.empty ())
256- return b.i32_val (0 );
257- while (terms.size () > 1 ) {
258- SmallVector<Value> next;
259- for (size_t i = 0 ; i + 1 < terms.size (); i += 2 )
260- next.push_back (op (terms[i], terms[i + 1 ]));
261- if (terms.size () % 2 == 1 )
262- next.push_back (terms.back ());
263- terms = std::move (next);
264- }
265- return terms[0 ];
266- };
267-
268- auto orPart = treeReduce (
269- ors, [&b](Value x, Value y) { return b.or_ (x, y, /* disjoint=*/ true ); });
270- auto xorPart =
271- treeReduce (xors, [&b](Value x, Value y) { return b.xor_ (x, y); });
272- return b.or_ (orPart, xorPart, /* disjoint=*/ true );
225+ return ret;
273226}
274227
275228} // namespace triton::gpu
0 commit comments