@@ -151,10 +151,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
151151 auto b = TritonLLVMOpBuilder (loc, rewriter);
152152 assert (layout.getNumInDims () == indices.size ());
153153 assert (llvm::equal (layout.getInDimNames (), llvm::make_first_range (indices)));
154- // Trivial layout
155- if (layout.getNumOutDims () == 0 ) {
156- return {};
157- }
158154
159155 // This function can emit a lot of MLIR code, which ultimately makes
160156 // compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -168,65 +164,62 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
168164 SmallVector<std::pair<StringAttr, int32_t >> constantIns;
169165 SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
170166 for (auto [inDimName, idx] : indices) {
171- APInt constant;
172- if ( matchPattern (idx, m_ConstantInt (&constant))) {
173- constantIns. push_back ( {inDimName, constant.getSExtValue ()});
167+ if ( auto constant = idx. getDefiningOp <LLVM::ConstantOp>()) {
168+ constantIns. push_back (
169+ {inDimName, cast<IntegerAttr>( constant.getValue ()). getInt ()});
174170 } else {
175171 constantIns.push_back ({inDimName, 0 });
176172 nonConstantIns.push_back ({inDimName, idx});
177173 }
178174 }
175+ SmallVector<int32_t > constantComponent =
176+ llvm::to_vector (llvm::make_second_range (layout.apply (constantIns)));
179177
180- // Compute constant part of the output and wrap it as values
181178 Value zero = b.i32_val (0 );
182179 SmallVector<std::pair<StringAttr, Value>> outIndices;
183- for (auto [outDimName, constant ] : layout.apply (constantIns )) {
184- if (constant == 0 )
180+ for (auto [i, outDimName ] : llvm::enumerate ( layout.getOutDimNames () )) {
181+ if (constantComponent[i] == 0 )
185182 outIndices.push_back ({outDimName, zero});
186183 else
187- outIndices.push_back ({outDimName, b.i32_val (constant)});
188- }
189-
190- if (nonConstantIns.size () == 0 ) {
191- return outIndices;
184+ outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
192185 }
193-
194- // Concatenate input
195- Value x = b.i32_val (0 );
196- if (nonConstantIns.size () == 1 ) {
197- x = nonConstantIns[0 ].second ;
198- } else {
186+ // Happy path: Only one output.
187+ if (outIndices.size () == 1 ) {
188+ SmallVector<StringAttr> inDimNames;
189+ // Concatenate input
190+ Value x = b.i32_val (0 );
199191 int shift = 0 ;
200192 for (auto [inDimName, idx] : nonConstantIns) {
193+ inDimNames.push_back (inDimName);
201194 x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
202195 shift += layout.getInDimSizeLog2 (inDimName);
203196 }
197+ // Flatten ins
198+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
199+ matrix = matrix.flattenIns ();
200+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
201+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
202+ return outIndices;
204203 }
205204
206- // Remove constant input dims from the layout and flatten it
207- auto inDimNames = llvm::to_vector (llvm::make_first_range (nonConstantIns));
208- auto matrix = layout.sublayout (
209- inDimNames, llvm::to_vector (llvm::make_first_range (outIndices)));
210- auto flatMatrix = matrix.flattenIns ().flattenOuts ();
211-
212- // Lower the matrix-vector product
213- auto out = triton::gpu::matrixVectorProd (b, flatMatrix, x);
205+ for (auto [inDimName, idx] : indices) {
206+ if (idx.getDefiningOp <LLVM::ConstantOp>()) {
207+ continue ;
208+ }
214209
215- // Unpack the output
216- if (matrix.getNumOutDims () == 1 ) {
217- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
218- } else {
219- assert (llvm::equal (matrix.getOutDimNames (),
220- llvm::make_first_range (outIndices)));
221- int shift = 0 ;
222- for (auto &[dimName, outIdx] : outIndices) {
223- auto outDimSizeLog2 = layout.getOutDimSizeLog2 (dimName);
224- auto mask = (1 << outDimSizeLog2) - 1 ;
225- outIdx = b.xor_ (outIdx,
226- b.and_ (b.lshr (out, b.i32_val (shift)), b.i32_val (mask)));
227- shift += outDimSizeLog2;
210+ int nBits = layout.getInDimSizeLog2 (inDimName);
211+ for (int i = 0 ; i < nBits; i++) {
212+ Value bit = b.and_ (idx, b.i32_val (1 << i));
213+ Value bit_is_zero = b.icmp_eq (bit, zero);
214+ for (auto &[outDimName, outIdx] : outIndices) {
215+ int32_t basis = layout.getBasis (inDimName, i, outDimName);
216+ if (basis == 0 )
217+ continue ;
218+ outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
219+ }
228220 }
229221 }
222+
230223 return outIndices;
231224}
232225
0 commit comments