@@ -113,6 +113,7 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
113
113
auto nRow = A.getTotalOutDimSizeLog2 ();
114
114
SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
115
115
assert (matrix.size () == nCol);
116
+
116
117
// We iterate the matrix following the diagonals
117
118
// The idea here is that we want to generate code of the form:
118
119
// \xor_i (x & mask_i) << s_i
@@ -133,15 +134,50 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
133
134
return mask;
134
135
};
135
136
137
+ uint32_t explicitCols = 0 ;
138
+
139
+ {
140
+ SmallVector<uint32_t > masks;
141
+ for (int i = -nRow + 1 ; i < nCol; i++) {
142
+ masks.push_back (getMask (i));
143
+ }
144
+ bool reachedFixedPoint = false ;
145
+ while (!reachedFixedPoint) {
146
+ reachedFixedPoint = true ;
147
+ for (uint32_t m : masks) {
148
+ uint32_t c = m & ~explicitCols;
149
+ if ((c != 0 ) && ((c & (c - 1 )) == 0 )) {
150
+ // found a single-element diagonal
151
+ explicitCols |= c;
152
+ reachedFixedPoint = false ;
153
+ }
154
+ }
155
+ }
156
+ }
157
+
158
+ // handle any diagonals that have survived
136
159
Value ret = b.i32_val (0 );
137
160
for (int i = -nRow + 1 ; i < nCol; i++) {
138
- auto mask = getMask (i);
161
+ auto mask = getMask (i) & ~explicitCols ;
139
162
if (mask == 0 )
140
163
continue ;
141
164
auto masked = b.and_ (x, b.i32_val (mask));
142
165
ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
143
166
: Value (b.shl (masked, b.i32_val (-i))));
144
167
}
168
+
169
+ // handle any explicit columns:
170
+ Value zero = b.i32_val (0 );
171
+ for (int i = 0 ; i < nCol; i++) {
172
+ if ((explicitCols >> i) & 1 ) {
173
+ Value bit = b.and_ (x, b.i32_val (1 << i));
174
+ Value bit_is_zero = b.icmp_eq (bit, zero);
175
+ int32_t basis = matrix[i];
176
+ if (basis == 0 )
177
+ continue ;
178
+ ret = b.xor_ (ret, b.select (bit_is_zero, zero, b.i32_val (basis)));
179
+ }
180
+ }
145
181
return ret;
146
182
}
147
183
0 commit comments