@@ -151,6 +151,10 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151151 auto b = TritonLLVMOpBuilder (loc, rewriter);
152152 assert (layout.getNumInDims () == indices.size ());
153153 assert (llvm::equal (layout.getInDimNames (), llvm::make_first_range (indices)));
154+ // Trivial layout
155+ if (layout.getNumOutDims () == 0 ) {
156+ return {};
157+ }
154158
155159 // This function can emit a lot of MLIR code, which ultimately makes
156160 // compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -164,62 +168,65 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
164168 SmallVector<std::pair<StringAttr, int32_t >> constantIns;
165169 SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
166170 for (auto [inDimName, idx] : indices) {
167- if ( auto constant = idx. getDefiningOp <LLVM::ConstantOp>()) {
168- constantIns. push_back (
169- {inDimName, cast<IntegerAttr>( constant.getValue ()). getInt ()});
171+ APInt constant;
172+ if ( matchPattern (idx, m_ConstantInt (&constant))) {
173+ constantIns. push_back ( {inDimName, constant.getSExtValue ()});
170174 } else {
171175 constantIns.push_back ({inDimName, 0 });
172176 nonConstantIns.push_back ({inDimName, idx});
173177 }
174178 }
175- SmallVector<int32_t > constantComponent =
176- llvm::to_vector (llvm::make_second_range (layout.apply (constantIns)));
177179
180+ // Compute constant part of the output and wrap it as values
178181 Value zero = b.i32_val (0 );
179182 SmallVector<std::pair<StringAttr, Value>> outIndices;
180- for (auto [i, outDimName ] : llvm::enumerate ( layout.getOutDimNames () )) {
181- if (constantComponent[i] == 0 )
183+ for (auto [outDimName, constant ] : layout.apply (constantIns )) {
184+ if (constant == 0 )
182185 outIndices.push_back ({outDimName, zero});
183186 else
184- outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
187+ outIndices.push_back ({outDimName, b.i32_val (constant)});
188+ }
189+
190+ if (nonConstantIns.size () == 0 ) {
191+ return outIndices;
185192 }
186- // Happy path: Only one output.
187- if (outIndices.size () == 1 ) {
188- SmallVector<StringAttr> inDimNames;
189- // Concatenate input
190- Value x = b.i32_val (0 );
193+
194+ // Concatenate input
195+ Value x = b.i32_val (0 );
196+ if (nonConstantIns.size () == 1 ) {
197+ x = nonConstantIns[0 ].second ;
198+ } else {
191199 int shift = 0 ;
192200 for (auto [inDimName, idx] : nonConstantIns) {
193- inDimNames.push_back (inDimName);
194201 x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
195202 shift += layout.getInDimSizeLog2 (inDimName);
196203 }
197- // Flatten ins
198- auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199- matrix = matrix.flattenIns ();
200- auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202- return outIndices;
203204 }
204205
205- for (auto [inDimName, idx] : indices) {
206- if (idx.getDefiningOp <LLVM::ConstantOp>()) {
207- continue ;
208- }
206+ // Remove constant input dims from the layout and flatten it
207+ auto inDimNames = llvm::to_vector (llvm::make_first_range (nonConstantIns));
208+ auto matrix = layout.sublayout (
209+ inDimNames, llvm::to_vector (llvm::make_first_range (outIndices)));
210+ auto flatMatrix = matrix.flattenIns ().flattenOuts ();
211+
212+ // Lower the matrix-vector product
213+ auto out = triton::gpu::matrixVectorProd (b, flatMatrix, x);
209214
210- int nBits = layout.getInDimSizeLog2 (inDimName);
211- for (int i = 0 ; i < nBits; i++) {
212- Value bit = b.and_ (idx, b.i32_val (1 << i));
213- Value bit_is_zero = b.icmp_eq (bit, zero);
214- for (auto &[outDimName, outIdx] : outIndices) {
215- int32_t basis = layout.getBasis (inDimName, i, outDimName);
216- if (basis == 0 )
217- continue ;
218- outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
219- }
215+ // Unpack the output
216+ if (matrix.getNumOutDims () == 1 ) {
217+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
218+ } else {
219+ assert (llvm::equal (matrix.getOutDimNames (),
220+ llvm::make_first_range (outIndices)));
221+ int shift = 0 ;
222+ for (auto &[dimName, outIdx] : outIndices) {
223+ auto outDimSizeLog2 = layout.getOutDimSizeLog2 (dimName);
224+ auto mask = (1 << outDimSizeLog2) - 1 ;
225+ outIdx = b.xor_ (outIdx,
226+ b.and_ (b.lshr (out, b.i32_val (shift)), b.i32_val (mask)));
227+ shift += outDimSizeLog2;
220228 }
221229 }
222-
223230 return outIndices;
224231}
225232
0 commit comments