@@ -95,53 +95,6 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
95
95
StringAttr::get (op->getContext (), libpath));
96
96
return ret;
97
97
}
98
-
99
- Value matrixVectorProd (TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
100
- assert (A.getNumInDims () == 1 );
101
- assert (A.getNumOutDims () == 1 );
102
- auto flatten = [](const std::vector<std::vector<int32_t >> &matrix) {
103
- SmallVector<int32_t > ret;
104
- for (const auto &row : matrix) {
105
- ret.push_back (row[0 ]);
106
- }
107
- return ret;
108
- };
109
- auto nCol = A.getTotalInDimSizeLog2 ();
110
- auto nRow = A.getTotalOutDimSizeLog2 ();
111
- SmallVector<int32_t > matrix = flatten (A.getBases ().begin ()->second );
112
- assert (matrix.size () == nCol);
113
- // We iterate the matrix following the diagonals
114
- // The idea here is that we want to generate code of the form:
115
- // \xor_i (x & mask_i) << s_i
116
- // where s_i may by positive or negative (left or right shift)
117
- // The hope here (and we see it in codegen) is that LLVM can turn
118
- // the xor into a sum and then the sum + LHS/RHS can be fused into a mad.lo
119
- // Get the i-th diagonal
120
- auto getMask = [&](int i) {
121
- uint32_t mask = 0 ;
122
- int row = i < 0 ? -i : 0 ;
123
- int col = i < 0 ? 0 : i;
124
- while (row < nRow && col < nCol) {
125
- uint32_t bitValue = (matrix[col] >> row) & 1u ;
126
- mask |= bitValue << col;
127
- ++row;
128
- ++col;
129
- }
130
- return mask;
131
- };
132
-
133
- Value ret = b.i32_val (0 );
134
- for (int i = -nRow + 1 ; i < nCol; i++) {
135
- auto mask = getMask (i);
136
- if (mask == 0 )
137
- continue ;
138
- auto masked = b.and_ (x, b.i32_val (mask));
139
- ret = b.xor_ (ret, i >= 0 ? Value (b.lshr (masked, b.i32_val (i)))
140
- : Value (b.shl (masked, b.i32_val (-i))));
141
- }
142
- return ret;
143
- }
144
-
145
98
} // namespace triton::gpu
146
99
147
100
SmallVector<std::pair<StringAttr, Value>>
@@ -162,14 +115,12 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
162
115
163
116
// Manually constant-fold the layout where possible.
164
117
SmallVector<std::pair<StringAttr, int32_t >> constantIns;
165
- SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
166
118
for (auto [inDimName, idx] : indices) {
167
119
if (auto constant = idx.getDefiningOp <LLVM::ConstantOp>()) {
168
120
constantIns.push_back (
169
121
{inDimName, cast<IntegerAttr>(constant.getValue ()).getInt ()});
170
122
} else {
171
123
constantIns.push_back ({inDimName, 0 });
172
- nonConstantIns.push_back ({inDimName, idx});
173
124
}
174
125
}
175
126
SmallVector<int32_t > constantComponent =
@@ -183,24 +134,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
183
134
else
184
135
outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
185
136
}
186
- // Happy path: Only one output.
187
- if (outIndices.size () == 1 ) {
188
- SmallVector<StringAttr> inDimNames;
189
- // Concatenate input
190
- Value x = b.i32_val (0 );
191
- int shift = 0 ;
192
- for (auto [inDimName, idx] : nonConstantIns) {
193
- inDimNames.push_back (inDimName);
194
- x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
195
- shift += layout.getInDimSizeLog2 (inDimName);
196
- }
197
- // Flatten ins
198
- auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199
- matrix = matrix.flattenIns ();
200
- auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201
- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202
- return outIndices;
203
- }
204
137
205
138
for (auto [inDimName, idx] : indices) {
206
139
if (idx.getDefiningOp <LLVM::ConstantOp>()) {
0 commit comments