@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
95
95
StringAttr::get (op->getContext (), libpath));
96
96
return ret;
97
97
}
98
+
99
+ Value matrixVectorProd (TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
100
+ assert (A.getNumInDims () == 1 );
101
+ assert (A.getNumOutDims () == 1 );
102
+ auto flatten = [](const std::vector<std::vector<int32_t >> &matrix) {
103
+ SmallVector<int32_t > ret;
104
+ for (const auto &row : matrix) {
105
+ ret.push_back (row[0 ]);
106
+ }
107
+ return ret;
108
+ };
109
+ auto nCol = A.getTotalInDimSizeLog2 ();
110
+ auto nRow = A.getTotalOutDimSizeLog2 ();
111
+ SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
112
+ assert (matrix.size () == nCol);
113
+ // We iterate the matrix following the diagonals
114
+ // The idea here is that we want to generate code of the form:
115
+ // \xor_i (x & mask_i) << s_i
116
+ // where s_i may by positive or negative (left or right shift)
117
+ // The hope here (and we see it in codegen) is that LLVM can turn
118
+ // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
119
+ // Get the i-th diagonal
120
+ auto getMask = [&](int i) {
121
+ uint32_t mask = 0 ;
122
+ int row = i < 0 ? -i : 0 ;
123
+ int col = i < 0 ? 0 : i;
124
+ while (row < nRow && col < nCol) {
125
+ uint32_t bitValue = (matrix[col] >> row) & 1u ;
126
+ mask |= bitValue << col;
127
+ ++row;
128
+ ++col;
129
+ }
130
+ return mask;
131
+ };
132
+
133
+ Value ret = b.i32_val (0 );
134
+ for (int i = -nRow + 1 ; i < nCol; i++) {
135
+ auto mask = getMask (i);
136
+ if (mask == 0 )
137
+ continue ;
138
+ auto masked = b.and_ (x, b.i32_val (mask));
139
+ ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
140
+ : Value (b.shl (masked, b.i32_val (-i))));
141
+ }
142
+ return ret;
143
+ }
144
+
98
145
} // namespace triton::gpu
99
146
100
147
SmallVector<std::pair<StringAttr, Value>>
@@ -117,12 +164,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
117
164
118
165
// Manually constant-fold the layout where possible.
119
166
SmallVector<std::pair<StringAttr, int32_t >> constantIns;
167
+ SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
120
168
for (auto [inDimName, idx] : indices) {
121
169
if (auto constant = idx.getDefiningOp <LLVM::ConstantOp>()) {
122
170
constantIns.push_back (
123
171
{inDimName, cast<IntegerAttr>(constant.getValue ()).getInt ()});
124
172
} else {
125
173
constantIns.push_back ({inDimName, 0 });
174
+ nonConstantIns.push_back ({inDimName, idx});
126
175
}
127
176
}
128
177
SmallVector<int32_t > constantComponent =
@@ -136,6 +185,28 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
136
185
else
137
186
outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
138
187
}
188
+ // Happy path: Only one output.
189
+ if (outIndices.size () == 1 ) {
190
+ SmallVector<StringAttr> inDimNames;
191
+ // Concatenate input
192
+ Value x = b.i32_val (0 );
193
+ int shift = 0 ;
194
+ for (auto orderedName : layout.getInDimNames ()) {
195
+ for (auto [inDimName, idx] : nonConstantIns) {
196
+ if (orderedName == inDimName) {
197
+ inDimNames.push_back (inDimName);
198
+ x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
199
+ shift += layout.getInDimSizeLog2 (inDimName);
200
+ }
201
+ }
202
+ }
203
+ // Flatten ins
204
+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
205
+ matrix = matrix.flattenIns ();
206
+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
207
+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
208
+ return outIndices;
209
+ }
139
210
140
211
for (auto [inDimName, idx] : indices) {
141
212
if (idx.getDefiningOp <LLVM::ConstantOp>()) {
0 commit comments