@@ -152,10 +152,6 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
152152 auto b = TritonLLVMOpBuilder (loc, rewriter);
153153 assert (layout.getNumInDims () == indices.size ());
154154 assert (llvm::equal (layout.getInDimNames (), llvm::make_first_range (indices)));
155- // Trivial layout
156- if (layout.getNumOutDims () == 0 ) {
157- return {};
158- }
159155
160156 // This function can emit a lot of MLIR code, which ultimately makes
161157 // compilation slow. (We think this shouldn't be the case -- it's not *that*
@@ -169,65 +165,62 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
169165 SmallVector<std::pair<StringAttr, int32_t >> constantIns;
170166 SmallVector<std::pair<StringAttr, Value>> nonConstantIns;
171167 for (auto [inDimName, idx] : indices) {
172- APInt constant;
173- if ( matchPattern (idx, m_ConstantInt (&constant))) {
174- constantIns. push_back ( {inDimName, constant.getSExtValue ()});
168+ if ( auto constant = idx. getDefiningOp <LLVM::ConstantOp>()) {
169+ constantIns. push_back (
170+ {inDimName, cast<IntegerAttr>( constant.getValue ()). getInt ()});
175171 } else {
176172 constantIns.push_back ({inDimName, 0 });
177173 nonConstantIns.push_back ({inDimName, idx});
178174 }
179175 }
176+ SmallVector<int32_t > constantComponent =
177+ llvm::to_vector (llvm::make_second_range (layout.apply (constantIns)));
180178
181- // Compute constant part of the output and wrap it as values
182179 Value zero = b.i32_val (0 );
183180 SmallVector<std::pair<StringAttr, Value>> outIndices;
184- for (auto [outDimName, constant ] : layout.apply (constantIns )) {
185- if (constant == 0 )
181+ for (auto [i, outDimName ] : llvm::enumerate ( layout.getOutDimNames () )) {
182+ if (constantComponent[i] == 0 )
186183 outIndices.push_back ({outDimName, zero});
187184 else
188- outIndices.push_back ({outDimName, b.i32_val (constant)});
189- }
190-
191- if (nonConstantIns.size () == 0 ) {
192- return outIndices;
185+ outIndices.push_back ({outDimName, b.i32_val (constantComponent[i])});
193186 }
194-
195- // Concatenate input
196- Value x = b.i32_val (0 );
197- if (nonConstantIns.size () == 1 ) {
198- x = nonConstantIns[0 ].second ;
199- } else {
187+ // Happy path: Only one output.
188+ if (outIndices.size () == 1 ) {
189+ SmallVector<StringAttr> inDimNames;
190+ // Concatenate input
191+ Value x = b.i32_val (0 );
200192 int shift = 0 ;
201193 for (auto [inDimName, idx] : nonConstantIns) {
194+ inDimNames.push_back (inDimName);
202195 x = b.or_ (x, b.shl (idx, b.i32_val (shift)));
203196 shift += layout.getInDimSizeLog2 (inDimName);
204197 }
198+ // Flatten ins
199+ auto matrix = layout.sublayout (inDimNames, outIndices[0 ].first );
200+ matrix = matrix.flattenIns ();
201+ auto out = triton::gpu::matrixVectorProd (b, matrix, x);
202+ outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
203+ return outIndices;
205204 }
206205
207- // Remove constant input dims from the layout and flatten it
208- auto inDimNames = llvm::to_vector (llvm::make_first_range (nonConstantIns));
209- auto matrix = layout.sublayout (
210- inDimNames, llvm::to_vector (llvm::make_first_range (outIndices)));
211- auto flatMatrix = matrix.flattenIns ().flattenOuts ();
212-
213- // Lower the matrix-vector product
214- auto out = triton::gpu::matrixVectorProd (b, flatMatrix, x);
206+ for (auto [inDimName, idx] : indices) {
207+ if (idx.getDefiningOp <LLVM::ConstantOp>()) {
208+ continue ;
209+ }
215210
216- // Unpack the output
217- if (matrix.getNumOutDims () == 1 ) {
218- outIndices[0 ].second = b.xor_ (outIndices[0 ].second , out);
219- } else {
220- assert (llvm::equal (matrix.getOutDimNames (),
221- llvm::make_first_range (outIndices)));
222- int shift = 0 ;
223- for (auto &[dimName, outIdx] : outIndices) {
224- auto outDimSizeLog2 = layout.getOutDimSizeLog2 (dimName);
225- auto mask = (1 << outDimSizeLog2) - 1 ;
226- outIdx = b.xor_ (outIdx,
227- b.and_ (b.lshr (out, b.i32_val (shift)), b.i32_val (mask)));
228- shift += outDimSizeLog2;
211+ int nBits = layout.getInDimSizeLog2 (inDimName);
212+ for (int i = 0 ; i < nBits; i++) {
213+ Value bit = b.and_ (idx, b.i32_val (1 << i));
214+ Value bit_is_zero = b.icmp_eq (bit, zero);
215+ for (auto &[outDimName, outIdx] : outIndices) {
216+ int32_t basis = layout.getBasis (inDimName, i, outDimName);
217+ if (basis == 0 )
218+ continue ;
219+ outIdx = b.xor_ (outIdx, b.select (bit_is_zero, zero, b.i32_val (basis)));
220+ }
229221 }
230222 }
223+
231224 return outIndices;
232225}
233226
0 commit comments