@@ -306,6 +306,53 @@ Value getLaneId(OpBuilder &rewriter, Location loc) {
306
306
return getLaneAndWarpId (rewriter, loc).first ;
307
307
}
308
308
309
+ // Helper function: applies linear layout vectorized over register indices
310
+ SmallVector<SmallVector<std::pair<StringAttr, Value>>>
311
+ applyLinearLayoutVec (Location loc, RewriterBase &rewriter,
312
+ const LinearLayout &layout,
313
+ ArrayRef<std::pair<StringAttr, Value>> indices,
314
+ ArrayRef<uint32_t > registers) {
315
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
316
+ MLIRContext *ctx = rewriter.getContext ();
317
+
318
+ StringAttr kRegister = str_attr (" register" );
319
+
320
+ // Precompute the base (with register = 0)
321
+ SmallVector<std::pair<StringAttr, Value>> indicesWithZeroReg;
322
+ for (const auto &[attr, val] : indices) {
323
+ if (attr == kRegister )
324
+ indicesWithZeroReg.emplace_back (attr, b.i32_val (0 ));
325
+ else
326
+ indicesWithZeroReg.emplace_back (attr, val);
327
+ }
328
+
329
+ auto baseIndices =
330
+ applyLinearLayout (loc, rewriter, layout, indicesWithZeroReg);
331
+
332
+ SmallVector<SmallVector<std::pair<StringAttr, Value>>> ret;
333
+
334
+ // Iterate over registers, applying XOR trick
335
+ for (auto reg : registers) {
336
+ SmallVector<std::pair<StringAttr, int32_t >> constRegIndices;
337
+ for (const auto &[attr, val] : indices) {
338
+ constRegIndices.emplace_back (attr, attr == kRegister ? reg : 0 );
339
+ }
340
+ auto regIndices = layout.apply (constRegIndices);
341
+
342
+ SmallVector<std::pair<StringAttr, Value>> combinedIndices;
343
+ for (auto [base, regIdx] : llvm::zip (baseIndices, regIndices)) {
344
+ assert (base.first == regIdx.first );
345
+ Value combined = b.xor_ (base.second , b.i32_val (regIdx.second ));
346
+ combinedIndices.emplace_back (base.first , combined);
347
+ }
348
+
349
+ ret.push_back (combinedIndices);
350
+ }
351
+
352
+ return ret;
353
+ }
354
+
355
+ // Refactored emitIndices function using applyLinearLayoutVec
309
356
SmallVector<SmallVector<Value>>
310
357
emitIndices (Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
311
358
Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -315,8 +362,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
315
362
316
363
LinearLayout ll = triton::gpu::toLinearLayout (shape, layout);
317
364
318
- // TODO(jlebar): We could add strong typing if we wanted; for now this is
319
- // "stringly typed".
320
365
StringAttr kRegister = str_attr (" register" );
321
366
StringAttr kLane = str_attr (" lane" );
322
367
StringAttr kWarp = str_attr (" warp" );
@@ -325,38 +370,29 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
325
370
auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
326
371
Value blockId =
327
372
withCTAOffset ? target.getClusterCTAId (rewriter, loc) : b.i32_val (0 );
373
+
374
+ SmallVector<std::pair<StringAttr, Value>> commonIndices = {
375
+ {kRegister , b.i32_val (0 )},
376
+ {kLane , laneId},
377
+ {kWarp , warpId},
378
+ {kBlock , blockId}};
379
+
380
+ // Vectorize over registers
381
+ SmallVector<uint32_t > registerIndices;
382
+ for (unsigned reg = 0 ; reg < ll.getInDimSize (kRegister ); ++reg)
383
+ registerIndices.push_back (reg);
384
+
385
+ auto vecIndices =
386
+ applyLinearLayoutVec (loc, rewriter, ll, commonIndices, registerIndices);
387
+
328
388
unsigned rank = shape.size ();
329
389
SmallVector<SmallVector<Value>> ret;
330
- // Linear layout function is split in two parts below:
331
- // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
332
- // idxs = idxsBase xor idxsReg
333
- //
334
- // L(0, t, w, b) part is the same for all registers,
335
- // so we hoist it out of the main register loop in the below.
336
- //
337
- // This approach produces code with lower register pressure and
338
- // less computations, compared to fused L(r,t,w,b) method.
339
- auto idxsBase = applyLinearLayout (loc, rewriter, ll,
340
- {{kRegister , b.i32_val (0 )},
341
- {kLane , laneId},
342
- {kWarp , warpId},
343
- {kBlock , blockId}});
344
- for (unsigned reg = 0 ; reg < ll.getInDimSize (str_attr (" register" )); reg++) {
345
- auto idxsReg =
346
- ll.apply ({{kRegister , reg}, {kLane , 0 }, {kWarp , 0 }, {kBlock , 0 }});
347
- SmallVector<std::pair<StringAttr, Value>> idxs;
348
- for (auto [idxBase, idxReg] : llvm::zip (idxsBase, idxsReg)) {
349
- auto dimName = idxBase.first ;
350
- assert (dimName == idxReg.first &&
351
- " dim names of block+warp+thread and register idx should be equal" );
352
- auto idx = b.xor_ (idxBase.second , b.i32_val (idxReg.second ));
353
- idxs.emplace_back (dimName, idx);
354
- }
355
- assert (idxs.size () == rank);
356
- for (unsigned k = 0 ; k < rank; ++k) {
357
- assert (idxs[k].first == str_attr (" dim" + std::to_string (k)));
358
- }
359
- ret.push_back (llvm::to_vector (llvm::make_second_range (idxs)));
390
+ for (auto &indices : vecIndices) {
391
+ SmallVector<Value> vals;
392
+ assert (indices.size () == rank);
393
+ for (auto &idx : indices)
394
+ vals.push_back (idx.second );
395
+ ret.push_back (vals);
360
396
}
361
397
362
398
return ret;
0 commit comments