@@ -296,6 +296,53 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc) {
   return {laneId, warpId};
 }
 
+// Helper function: applies a linear layout, vectorized over register indices.
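+// For each register index in `registers`, it computes layout(indices) with
+// that register substituted, reusing the register-independent part of the
+// computation across all of them.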
+SmallVector<SmallVector<std::pair<StringAttr, Value>>>
+applyLinearLayoutVec(Location loc, RewriterBase &rewriter,
+                     const LinearLayout &layout,
+                     ArrayRef<std::pair<StringAttr, Value>> indices,
+                     ArrayRef<uint32_t> registers) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  MLIRContext *ctx = rewriter.getContext();
+
+  StringAttr kRegister = str_attr("register");
+
+  // Precompute the base indices (with register = 0).
+  SmallVector<std::pair<StringAttr, Value>> indicesWithZeroReg;
+  for (const auto &[attr, val] : indices) {
+    if (attr == kRegister)
+      indicesWithZeroReg.emplace_back(attr, b.i32_val(0));
+    else
+      indicesWithZeroReg.emplace_back(attr, val);
+  }
+
+  auto baseIndices =
+      applyLinearLayout(loc, rewriter, layout, indicesWithZeroReg);
+
+  SmallVector<SmallVector<std::pair<StringAttr, Value>>> ret;
+
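+  // The linear layout function is split in two parts:
+  //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
+  // L(0, t, w, b) is the same for all registers, so it was hoisted out
+  // above; this lowers register pressure and does less computation than
+  // evaluating the fused L(r, t, w, b) for every register.
+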
+  // Iterate over registers, applying the XOR trick.
+  for (auto reg : registers) {
+    SmallVector<std::pair<StringAttr, int32_t>> constRegIndices;
+    for (const auto &[attr, val] : indices) {
+      constRegIndices.emplace_back(attr, attr == kRegister ? reg : 0);
+    }
+    auto regIndices = layout.apply(constRegIndices);
+
+    SmallVector<std::pair<StringAttr, Value>> combinedIndices;
+    for (auto [base, regIdx] : llvm::zip(baseIndices, regIndices)) {
+      assert(base.first == regIdx.first);
+      Value combined = b.xor_(base.second, b.i32_val(regIdx.second));
+      combinedIndices.emplace_back(base.first, combined);
+    }
+
+    ret.push_back(combinedIndices);
+  }
+
+  return ret;
+}
+
+// emitIndices, refactored to use applyLinearLayoutVec.
 SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -305,8 +352,6 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
 
   LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
 
-  // TODO(jlebar): We could add strong typing if we wanted; for now this is
-  // "stringly typed".
   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
   StringAttr kWarp = str_attr("warp");
@@ -315,38 +360,29 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   Value blockId =
       withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
+
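+  // The register entry below is a placeholder; applyLinearLayoutVec
+  // substitutes each index from registerIndices in turn.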
+  SmallVector<std::pair<StringAttr, Value>> commonIndices = {
+      {kRegister, b.i32_val(0)},
+      {kLane, laneId},
+      {kWarp, warpId},
+      {kBlock, blockId}};
+
+  // Vectorize over all register indices of the layout.
+  SmallVector<uint32_t> registerIndices;
+  for (unsigned reg = 0; reg < ll.getInDimSize(kRegister); ++reg)
+    registerIndices.push_back(reg);
+
+  auto vecIndices =
+      applyLinearLayoutVec(loc, rewriter, ll, commonIndices, registerIndices);
+
   unsigned rank = shape.size();
   SmallVector<SmallVector<Value>> ret;
-  // Linear layout function is split in two parts below:
-  // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-  // idxs = idxsBase xor idxsReg
-  //
-  // L(0, t, w, b) part is the same for all registers,
-  // so we hoist it out of the main register loop in the below.
-  //
-  // This approach produces code with lower register pressure and
-  // less computations, compared to fused L(r,t,w,b) method.
-  auto idxsBase = applyLinearLayout(loc, rewriter, ll,
-                                    {{kRegister, b.i32_val(0)},
-                                     {kLane, laneId},
-                                     {kWarp, warpId},
-                                     {kBlock, blockId}});
-  for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) {
-    auto idxsReg =
-        ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
-    SmallVector<std::pair<StringAttr, Value>> idxs;
-    for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) {
-      auto dimName = idxBase.first;
-      assert(dimName == idxReg.first &&
-             "dim names of block+warp+thread and register idx should be equal");
-      auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second));
-      idxs.emplace_back(dimName, idx);
-    }
-    assert(idxs.size() == rank);
-    for (unsigned k = 0; k < rank; ++k) {
-      assert(idxs[k].first == str_attr("dim" + std::to_string(k)));
-    }
-    ret.push_back(llvm::to_vector(llvm::make_second_range(idxs)));
+  for (auto &indices : vecIndices) {
+    SmallVector<Value> vals;
+    assert(indices.size() == rank);
+    for (auto &idx : indices)
+      vals.push_back(idx.second);
+    ret.push_back(vals);
   }
 
   return ret;
@@ -781,8 +817,7 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
                               ArrayRef<Value> srcVals,
                               const SharedMemoryObject &smemObj, Location loc,
                               RewriterBase &rewriter,
-                              const TargetInfoBase &target,
-                              std::pair<size_t, Type> *const llvmOpCount) {
+                              const TargetInfoBase &target) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
@@ -797,10 +832,6 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
         b.store(vec, vecAddr)
             .setAlignment(vecTy.getNumElements() *
                           elemLlvmTy.getIntOrFloatBitWidth() / 8);
-        if (llvmOpCount) {
-          ++(llvmOpCount->first);
-          llvmOpCount->second = vecTy;
-        }
       });
 
   if (!success)
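
For intuition, here is a minimal standalone sketch (not Triton code; the
ToyLinearLayout type and its basis values are made up for illustration) of the
GF(2) linearity that makes the hoisting valid: each set bit of an input selects
a basis vector, selected bases combine with XOR, and therefore
L(r, t, w, b) == L(0, t, w, b) xor L(r, 0, 0, 0).

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a GF(2)-linear layout with two input dimensions.
struct ToyLinearLayout {
  std::vector<uint32_t> regBases;  // bases for the "register" input dim
  std::vector<uint32_t> laneBases; // bases for the "lane" input dim

  uint32_t apply(uint32_t reg, uint32_t lane) const {
    uint32_t out = 0;
    // Each set bit of an input XORs the corresponding basis into the result.
    for (size_t i = 0; i < regBases.size(); ++i)
      if (reg & (1u << i))
        out ^= regBases[i];
    for (size_t i = 0; i < laneBases.size(); ++i)
      if (lane & (1u << i))
        out ^= laneBases[i];
    return out;
  }
};

int main() {
  ToyLinearLayout ll{{1, 2, 8}, {4, 16}};
  for (uint32_t reg = 0; reg < 8; ++reg)
    for (uint32_t lane = 0; lane < 4; ++lane)
      // L(r, t) == L(0, t) xor L(r, 0): the lane-dependent base can be
      // computed once and reused, with only the register part varying.
      assert(ll.apply(reg, lane) == (ll.apply(0, lane) ^ ll.apply(reg, 0)));
}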