@@ -95,6 +95,53 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
95
95
StringAttr::get (op->getContext (), libpath));
96
96
return ret;
97
97
}
98
+
99
+ Value matrixVectorProd (TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
100
+ assert (A.getNumInDims () == 1 );
101
+ assert (A.getNumOutDims () == 1 );
102
+ auto flatten = [](const std::vector<std::vector<int32_t >> &matrix) {
103
+ SmallVector<int32_t > ret;
104
+ for (const auto &row : matrix) {
105
+ ret.push_back (row[0 ]);
106
+ }
107
+ return ret;
108
+ };
109
+ auto nCol = A.getTotalInDimSizeLog2 ();
110
+ auto nRow = A.getTotalOutDimSizeLog2 ();
111
+ SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
112
+ assert (matrix.size () == nCol);
113
+ // We iterate the matrix following the diagonals
114
+ // The idea here is that we want to generate code of the form:
115
+ // \xor_i (x & mask_i) << s_i
116
+ // where s_i may by positive or negative (left or right shift)
117
+ // The hope here (and we see it in codegen) is that LLVM can turn
118
+ // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
119
+ // Get the i-th diagonal
120
+ auto getMask = [&](int i) {
121
+ uint32_t mask = 0 ;
122
+ int row = i < 0 ? -i : 0 ;
123
+ int col = i < 0 ? 0 : i;
124
+ while (row < nRow && col < nCol) {
125
+ uint32_t bitValue = (matrix[col] >> row) & 1u ;
126
+ mask |= bitValue << col;
127
+ ++row;
128
+ ++col;
129
+ }
130
+ return mask;
131
+ };
132
+
133
+ Value ret = b.i32_val (0 );
134
+ for (int i = -nRow + 1 ; i < nCol; i++) {
135
+ auto mask = getMask (i);
136
+ if (mask == 0 )
137
+ continue ;
138
+ auto masked = b.and_ (x, b.i32_val (mask));
139
+ ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
140
+ : Value (b.shl (masked, b.i32_val (-i))));
141
+ }
142
+ return ret;
143
+ }
144
+
98
145
} // namespace triton::gpu
99
146
100
147
SmallVector<std::pair<StringAttr, Value>>
@@ -115,12 +162,14 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
115
162
116
163
// Manually constant-fold the layout where possible.
117
164
SmallVector<std::pair<StringAttr, int32_t >> constantIns;
165
+ SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
118
166
for (auto [inDimName, idx] : indices) {
119
167
if (auto constant = idx.getDefiningOp <LLVM::ConstantOp>()) {
120
168
constantIns.push_back (
121
169
{inDimName, cast<IntegerAttr>(constant.getValue ()).getInt ()});
122
170
} else {
123
171
constantIns.push_back ({inDimName, 0 });
172
+ nonConstantIns.push_back ({inDimName, idx});
124
173
}
125
174
}
126
175
SmallVector<int32_t > constantComponent =
@@ -134,6 +183,24 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
134
183
else
135
184
outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
136
185
}
186
+ // Happy path: Only one output.
187
+ if (outIndices.size () == 1 ) {
188
+ SmallVector<StringAttr> inDimNames;
189
+ // Concatenate input
190
+ Value x = b.i32_val (0 );
191
+ int shift = 0 ;
192
+ for (auto [inDimName, idx] : nonConstantIns) {
193
+ inDimNames.push_back (inDimName);
194
+ x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
195
+ shift += layout.getInDimSizeLog2 (inDimName);
196
+ }
197
+ // Flatten ins
198
+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199
+ matrix = matrix.flattenIns ();
200
+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201
+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202
+ return outIndices;
203
+ }
137
204
138
205
for (auto [inDimName, idx] : indices) {
139
206
if (idx.getDefiningOp <LLVM::ConstantOp>()) {
0 commit comments