@@ -272,13 +272,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                                  const LinearLayout &dstLayout,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-
-    assert(cvtNeedsSharedMemory(srcTy, dstTy));
+    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     // Try to use swizzling to implement the conversion
     // HACK Remove once AMD tests pass for the swizzling path
@@ -287,52 +281,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return success();
     }
 
-    SmallVector<Value> inVals =
-        unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(!inVals.empty());
-
-    // We munge the input values by converting i<n> (n<8) elements to i8 and
-    // pointers to i64. This is necessary because TargetInfo::loadDShared and
-    // storeDShared can't handle vectors of pointers or sub-byte elements.
-    auto elemTy = srcTy.getElementType();
-    auto isSubByteInt =
-        elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8;
-    auto isPtr = isa<triton::PointerType>(elemTy);
-    auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
-    if (isSubByteInt)
-      elemTy = IntegerType::get(elemTy.getContext(), 8);
-    else if (isPtr)
-      elemTy = IntegerType::get(elemTy.getContext(), 64);
-    auto llvmElemTy = getTypeConverter()->convertType(elemTy);
-
-    // Munge input values
-    for (const auto &it : llvm::enumerate(inVals)) {
-      if (isSubByteInt) {
-        inVals[it.index()] = b.zext(llvmElemTy, it.value());
-      } else if (isPtr) {
-        inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
-      }
-    }
-
-    // Pretty sure this is the identity function ATM
-    // It'd be better to simply call `quotient({kBlock})` and
-    // remove kBlock from transferWithinBlockImpl
-    auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
-    auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
-    SmallVector<Value> outVals = transferWithinBlockImpl(
-        inVals, op, srcLayoutWithinBlock, dstLayoutWithinBlock, rewriter);
-
-    // Unmunge output values
-    for (const auto &it : llvm::enumerate(outVals)) {
-      if (isSubByteInt) {
-        outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
-      } else if (isPtr) {
-        outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
-      }
-    }
+    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                              getTypeConverter(), rewriter);
 
-    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
-                                  op.getType());
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -343,207 +294,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                                  DecomposedWarpConversion decomposed,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const;
-
-  SmallVector<Value>
-  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
-                          const LinearLayout &srcLayout,
-                          const LinearLayout &dstLayout,
-                          ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kOffset = str_attr("offset");
-    StringAttr kIteration = str_attr("iteration");
-
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-
-    auto scratchConfig =
-        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
-    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
-        op.getSrc().getType().getEncoding(), op.getType().getShape()));
-    // Input dims: [offset, iteration, block]
-    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
-    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
-        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
-
-    // Layout for the store from registers to shared memory.
-    //
-    // Note: If two threads in the same warp write to the same shmem offset, the
-    // hardware resolves that without a stall or a bank conflict. Therefore we
-    // don't need to avoid duplicate writes.
-    // Input dims: [reg, lane, warp]
-    // Output dims: [offset, iteration]
-    bool isStMatrix = targetInfo.canUseStMatrix(
-        op.getSrc().getType(), scratchConfig.repShape,
-        scratchConfig.paddedRepShape, scratchConfig.order,
-        /*swizzleByteSize=*/0);
-    LinearLayout shmemStoreLayout =
-        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
-                                          /*swizzleByteSize=*/0)
-                   : srcLayout.invertAndCompose(sharedLayout);
-
-    const int shmemAllocatedNumElems =
-        getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
-
-    // Layout for the load from shmem to registers.
-    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
-
-    // Check that the `register` fully determines the `iteration`. That is,
-    // each thread does exactly the same reads and writes to shmem on each
-    // iteration, just with different input/output registers.
-    assert(
-        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-    assert(
-        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    // iteration -> registers
-    SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, shmemStoreLayout);
-    SmallVector<SmallVector<int>> outRegsForIter =
-        collectRegsForIter(ctx, shmemLoadLayout);
-
-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
-    auto sharedPtrTy = smemBase.getType();
-    Type elemTy = inVals[0].getType();
-    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
-    auto iterations = sharedLayout.getInDimSize(kIteration);
-    assert(scratchConfig.inVec * iterations <= inVals.size());
-    assert(scratchConfig.outVec * iterations <= outSize);
-
-    // Check only one dimension has been padded.
-    // This means the difference between the padded shape and the original shape
-    // should only be in one dimension, specifically in
-    // `scratchConfig.order[0]`.
-    auto rank = scratchConfig.repShape.size();
-    for (auto i = 0; i < rank; i++) {
-      if (i == scratchConfig.order[0]) {
-        continue;
-      }
-      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
-    }
-    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
-    auto paddedSize =
-        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
-
-    // Linear layout function is split in two parts below:
-    //
-    //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-    //   offset        =    regBase    xor    regIdx
-    //
-    // It is the same hack as what we've done in the emitIndices function to get
-    // around performance issues on AMD GPUs
-    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
-                          int regSlice) -> Value {
-      auto regIdx = layout
-                        .apply({{kRegister, regSlice},
-                                {kLane, 0},
-                                {kWarp, 0},
-                                {kBlock, 0}})[0]
-                        .second;
-      Value offset = b.xor_(regBase, b.i32_val(regIdx));
-      if (paddedSize > 0) {
-        assert(llvm::isPowerOf2_32(paddedStride));
-        assert(llvm::isPowerOf2_32(paddedSize));
-        auto rshiftVal = llvm::Log2_32(paddedStride);
-        auto lshiftVal = llvm::Log2_32(paddedSize);
-        offset = b.add(
-            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
-            offset);
-      }
-      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset,
-                           LLVM::GEPNoWrapFlags::inbounds);
-      return vecAddr;
-    };
-
-    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
-                                       {{kRegister, b.i32_val(0)},
-                                        {kLane, laneId},
-                                        {kWarp, warpId},
-                                        {kBlock, b.i32_val(0)}})[0]
-                         .second;
-    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
-                                      {{kRegister, b.i32_val(0)},
-                                       {kLane, laneId},
-                                       {kWarp, warpId},
-                                       {kBlock, b.i32_val(0)}})[0]
-                        .second;
-    // register idx -> Value
-    llvm::MapVector<int, Value> outVals;
-    for (int i = 0; i < iterations; i++) {
-      if (i != 0)
-        b.barrier();
-
-      auto &inRegs = inRegsForIter[i];
-      auto &outRegs = outRegsForIter[i];
-
-      // When using `stmatrix`, we can store `inVec` elements even if they are
-      // not contiguous
-      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
-                              : scratchConfig.inVec;
-      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
-        auto inRegSlice = inRegs[j];
-        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
-        SmallVector<Value> inValsVec;
-        for (int k = 0; k < inVec; k++)
-          inValsVec.push_back(inVals[inRegSlice + k]);
-        Value valsVec = packLLVector(loc, inValsVec, rewriter);
-        if (isStMatrix) {
-          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
-        } else {
-          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                  /*pred=*/b.true_val());
-        }
-      }
-
-      b.barrier();
-
-      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
-        auto outRegSlice = outRegs[j];
-        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
-        Value valsVec =
-            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
-                                   vec_ty(elemTy, scratchConfig.outVec),
-                                   /*pred=*/b.true_val());
-        for (Value v : unpackLLVector(loc, valsVec, rewriter))
-          outVals[outRegSlice++] = v;
-      }
-    }
-
-    SmallVector<Value> outValsVec;
-    for (size_t i = 0; i < outVals.size(); i++)
-      outValsVec.push_back(outVals[i]);
-    return outValsVec;
-  }
-
-  // Determine which registers are read/written in which iteration of the shmem
-  // transfer specified by `layout`.
-  SmallVector<SmallVector<int> /*registers*/>
-  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kIteration = str_attr("iteration");
-
-    // The choice of iteration should be determined only by the register. That
-    // is, it should be correct to split the register dimension into iterations.
-    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
-    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
-    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
-      auto idx = sublayout.apply({{kRegister, reg}});
-      ret[idx.begin()->second].push_back(reg);
-    }
-    return ret;
-  }
 };
 
 } // namespace
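
The removed getVecAddr lambda computes each vector's shared-memory offset as regBase xor regIdx and then widens it to account for the padded dimension. A minimal standalone sketch of that arithmetic on plain integers follows; sharedOffset and its scalar arguments are hypothetical names, since the real code emits LLVM IR values through TritonLLVMOpBuilder rather than computing offsets on the host.

#include <cassert>
#include <cstdint>
#include <cstdio>

// The linear layout is split as L(r, t, w, b) = L(0, t, w, b) ^ L(r, 0, 0, 0),
// so a per-thread base (regBase) is XORed with a per-register index (regIdx).
// When one dimension is padded, rows of `paddedStride` elements are separated
// by `paddedSize` extra elements; both are powers of two, so the adjustment
// reduces to shifts: offset + (offset / paddedStride) * paddedSize.
uint32_t sharedOffset(uint32_t regBase, uint32_t regIdx, uint32_t paddedStride,
                      uint32_t paddedSize) {
  uint32_t offset = regBase ^ regIdx;
  if (paddedSize == 0)
    return offset;
  assert((paddedStride & (paddedStride - 1)) == 0 && "power of two");
  assert((paddedSize & (paddedSize - 1)) == 0 && "power of two");
  uint32_t rshift = __builtin_ctz(paddedStride); // log2(paddedStride)
  uint32_t lshift = __builtin_ctz(paddedSize);   // log2(paddedSize)
  return offset + ((offset >> rshift) << lshift);
}

int main() {
  // Example: regBase = 0x40, regIdx = 0x02, 64-element rows padded by 4
  // elements -> linear offset 66 maps to padded offset 66 + 1 * 4 = 70.
  std::printf("%u\n", sharedOffset(0x40, 0x02, /*paddedStride=*/64,
                                   /*paddedSize=*/4));
}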