@@ -158,39 +158,54 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
158158 SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
159159 assert (matrix.size () == nCol);
160160
161- // We iterate the matrix following the diagonals
162- // The idea here is that we want to generate code of the form:
163- // \xor_i (x & mask_i) << s_i
164- // where s_i may be positive or negative (left or right shift)
165- // The hope here (and we see it in codegen) is that LLVM can turn
166- // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
167- // Get the i-th diagonal
168- auto getMask = [&](int i) {
161+ // Row-wise popcount to detect rows that appear exactly once across columns.
162+ uint32_t rowsUnique = 0 ;
163+ {
164+ SmallVector<int > rowPopCnt (nRow, 0 );
165+ for (int c = 0 ; c < nCol; ++c) {
166+ uint32_t colBits = matrix[c];
167+ for (int r = 0 ; r < nRow; ++r) {
168+ if (colBits & (1u << r))
169+ ++rowPopCnt[r];
170+ }
171+ }
172+ for (int r = 0 ; r < nRow; ++r) {
173+ if (rowPopCnt[r] == 1 )
174+ rowsUnique |= 1u << r;
175+ }
176+ }
177+
178+ // We iterate the matrix following the diagonals and build
179+ // (x & mask_i) << s_i terms. Prefer OR for diagonals whose rows are unique,
180+ // then XOR everything else. This tends to encourage mad.lo codegen.
181+ auto getMaskAndAllRowsUnique = [&](int i) -> std::pair<uint32_t , bool > {
169182 uint32_t mask = 0 ;
170183 int row = i < 0 ? -i : 0 ;
171184 int col = i < 0 ? 0 : i;
185+ bool allRowsUnique = true ;
172186 while (row < nRow && col < nCol) {
173187 uint32_t bitValue = (matrix[col] >> row) & 1u ;
174188 mask |= bitValue << col;
189+ allRowsUnique &= ((rowsUnique >> row) & 1u ) == 1u ;
175190 ++row;
176191 ++col;
177192 }
178- return mask;
193+ return { mask, allRowsUnique} ;
179194 };
180195
181196 uint32_t explicitCols = 0 ;
182197
183198 {
184199 SmallVector<uint32_t > masks;
185200 for (int i = -nRow + 1 ; i < nCol; i++) {
186- masks.push_back (getMask (i ));
201+ masks.push_back (std::get< 0 >( getMaskAndAllRowsUnique (i) ));
187202 }
188203 bool reachedFixedPoint = false ;
189204 while (!reachedFixedPoint) {
190205 reachedFixedPoint = true ;
191206 for (uint32_t m : masks) {
192207 uint32_t c = m & ~explicitCols;
193- if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
208+ if (llvm::isPowerOf2_32 (c )) {
194209 // found a single-element diagonal
195210 explicitCols |= c;
196211 reachedFixedPoint = false ;
@@ -200,14 +215,21 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
200215 }
201216
202217 // handle any diagonals that have survived
203- Value ret = b.i32_val (0 );
218+ SmallVector<Value> ors;
219+ SmallVector<Value> xors;
204220 for (int i = -nRow + 1 ; i < nCol; i++) {
205- auto mask = getMask (i) & ~explicitCols;
221+ auto [mask, allRowsUnique] = getMaskAndAllRowsUnique (i);
222+ mask &= ~explicitCols;
206223 if (mask == 0 )
207224 continue ;
208225 auto masked = b.and_ (x, b.i32_val (mask));
209- ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
210- : Value (b.shl (masked, b.i32_val (-i))));
226+ auto shifted = i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
227+ : Value (b.shl (masked, b.i32_val (-i)));
228+ if (allRowsUnique) {
229+ ors.push_back (shifted);
230+ } else {
231+ xors.push_back (shifted);
232+ }
211233 }
212234
213235 // handle any explicit columns:
@@ -219,10 +241,35 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
219241 int32_t basis = matrix[i];
220242 if (basis == 0 )
221243 continue ;
222- ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
244+ auto select = b.select (bit_is_zero, zero, b.i32_val (basis));
245+ if ((rowsUnique & basis) == basis) {
246+ ors.push_back (select);
247+ } else {
248+ xors.push_back (select);
249+ }
223250 }
224251 }
225- return ret;
252+
253+ auto treeReduce = [&](SmallVector<Value> &terms,
254+ std::function<Value (Value, Value)> op) -> Value {
255+ if (terms.empty ())
256+ return b.i32_val (0 );
257+ while (terms.size () > 1 ) {
258+ SmallVector<Value> next;
259+ for (size_t i = 0 ; i + 1 < terms.size (); i += 2 )
260+ next.push_back (op (terms[i], terms[i + 1 ]));
261+ if (terms.size () % 2 == 1 )
262+ next.push_back (terms.back ());
263+ terms = std::move (next);
264+ }
265+ return terms[0 ];
266+ };
267+
268+ auto orPart = treeReduce (
269+ ors, [&b](Value x, Value y) { return b.or_ (x, y, /* disjoint=*/ true ); });
270+ auto xorPart =
271+ treeReduce (xors, [&b](Value x, Value y) { return b.xor_ (x, y); });
272+ return b.or_ (orPart, xorPart, /* disjoint=*/ true );
226273}
227274
228275} // namespace triton::gpu
@@ -542,18 +589,20 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
542589 return unpackLLVector (loc, valsVec, rewriter);
543590 }
544591 };
592+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
545593 return lowerLdSt (loc, ctx, cvt, valsArray, llvmElemTy, smemBase,
546- calcPaddedOffset, affineOffset, maskSpanAffineOffset,
547- rewriter, targetInfo, {}, emitLdSt);
594+ calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId,
595+ warpId, rewriter, targetInfo, {}, emitLdSt);
548596}
549597
550598SmallVector<Value> lowerLdSt (
551599 Location loc, MLIRContext *ctx, LinearLayout cvt,
552600 ArrayRef<Value> valsArray, // Input for store, output for load
553601 Type llvmElemTy, Value smemBase,
554602 std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
555- uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
556- const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
603+ uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
604+ RewriterBase &rewriter, const TargetInfoBase &targetInfo,
605+ std::optional<int> maybeMaxVecElems,
557606 std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
558607 Value, int , VectorType)>
559608 lowerInst) {
@@ -599,7 +648,6 @@ SmallVector<Value> lowerLdSt(
599648 zerosLike (LinearLayout::identity1D (bitwidth / 8 , kReg , kOffset ));
600649 auto i8AddrLayout = i8Tile * addrLayout;
601650
602- auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
603651 auto regBaseI8 =
604652 applyLinearLayout (
605653 loc, rewriter, i8AddrLayout,
@@ -2022,16 +2070,17 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
20222070 };
20232071
20242072 auto noPaddingOffset = [](Value v) { return v; };
2073+ auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
20252074 lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
20262075 /* calcPaddedOffset=*/ noPaddingOffset, /* affineOffset=*/ b.i32_val (0 ),
2027- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo,
2076+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter, targetInfo,
20282077 /* maybeMaxVecElems=*/ {}, emitSt);
20292078 b.barrier ();
20302079 resultVals = lowerLdSt (loc, ctx, dstLayout, resultVals, valueElemTy, smemBase,
20312080 /* calcPaddedOffset=*/ noPaddingOffset,
20322081 /* affineOffset=*/ b.i32_val (0 ),
2033- /* maskSpanAffineOffset=*/ 0 , rewriter, targetInfo ,
2034- /* maybeMaxVecElems=*/ {}, emitLd);
2082+ /* maskSpanAffineOffset=*/ 0 , laneId, warpId, rewriter ,
2083+ targetInfo, /* maybeMaxVecElems=*/ {}, emitLd);
20352084
20362085 // Create the result struct and replace the operation
20372086 Value resultStruct =
0 commit comments