@@ -419,6 +419,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
419
419
}
420
420
};
421
421
422
+ // TODO: AMDGPU backend already have all this bitpacking logic, we should move
423
+ // it to some common place.
424
+ // / Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
425
+ // / Vmcnt = Waitcnt[3:0] (pre-gfx9)
426
+ // / Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
427
+ // / Vmcnt = Waitcnt[15:10] (gfx11)
428
+ // / Expcnt = Waitcnt[6:4] (pre-gfx11)
429
+ // / Expcnt = Waitcnt[2:0] (gfx11)
430
+ // / Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
431
+ // / Lgkmcnt = Waitcnt[13:8] (gfx10)
432
+ // / Lgkmcnt = Waitcnt[9:4] (gfx11)
433
+ static FailureOr<unsigned > encodeWaitcnt (Chipset chipset, unsigned vmcnt,
434
+ unsigned expcnt, unsigned lgkmcnt) {
435
+ if (chipset.majorVersion < 9 ) {
436
+ vmcnt = std::min (15u , vmcnt);
437
+ expcnt = std::min (7u , expcnt);
438
+ lgkmcnt = std::min (15u , lgkmcnt);
439
+ return vmcnt | (expcnt << 4 ) | (lgkmcnt << 8 );
440
+ }
441
+ if (chipset.majorVersion == 9 ) {
442
+ vmcnt = std::min (63u , vmcnt);
443
+ expcnt = std::min (7u , expcnt);
444
+ lgkmcnt = std::min (15u , lgkmcnt);
445
+ unsigned lowBits = vmcnt & 0xF ;
446
+ unsigned highBits = (vmcnt >> 4 ) << 14 ;
447
+ unsigned otherCnts = (expcnt << 4 ) | (lgkmcnt << 8 );
448
+ return lowBits | highBits | otherCnts;
449
+ }
450
+ if (chipset.majorVersion == 10 ) {
451
+ vmcnt = std::min (63u , vmcnt);
452
+ expcnt = std::min (7u , expcnt);
453
+ lgkmcnt = std::min (63u , lgkmcnt);
454
+ unsigned lowBits = vmcnt & 0xF ;
455
+ unsigned highBits = (vmcnt >> 4 ) << 14 ;
456
+ unsigned otherCnts = (expcnt << 4 ) | (lgkmcnt << 8 );
457
+ return lowBits | highBits | otherCnts;
458
+ }
459
+ if (chipset.majorVersion == 11 ) {
460
+ vmcnt = std::min (63u , vmcnt);
461
+ expcnt = std::min (7u , expcnt);
462
+ lgkmcnt = std::min (63u , lgkmcnt);
463
+ return (vmcnt << 10 ) | expcnt | (lgkmcnt << 4 );
464
+ }
465
+ return failure ();
466
+ }
467
+
468
+ struct MemoryCounterWaitOpLowering
469
+ : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
470
+ MemoryCounterWaitOpLowering (const LLVMTypeConverter &converter,
471
+ Chipset chipset)
472
+ : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
473
+ chipset (chipset) {}
474
+
475
+ Chipset chipset;
476
+
477
+ LogicalResult
478
+ matchAndRewrite (MemoryCounterWaitOp op, OpAdaptor adaptor,
479
+ ConversionPatternRewriter &rewriter) const override {
480
+ if (chipset.majorVersion >= 12 ) {
481
+ Location loc = op.getLoc ();
482
+ if (std::optional<int > ds = adaptor.getDs ())
483
+ rewriter.create <ROCDL::WaitDscntOp>(loc, *ds);
484
+
485
+ if (std::optional<int > load = adaptor.getLoad ())
486
+ rewriter.create <ROCDL::WaitLoadcntOp>(loc, *load);
487
+
488
+ if (std::optional<int > store = adaptor.getStore ())
489
+ rewriter.create <ROCDL::WaitStorecntOp>(loc, *store);
490
+
491
+ if (std::optional<int > exp = adaptor.getExp ())
492
+ rewriter.create <ROCDL::WaitExpcntOp>(loc, *exp);
493
+
494
+ rewriter.eraseOp (op);
495
+ return success ();
496
+ }
497
+
498
+ auto getVal = [](Attribute attr) -> unsigned {
499
+ if (attr)
500
+ return cast<IntegerAttr>(attr).getInt ();
501
+
502
+ // This value will be clamped to the maximum value for the chipset.
503
+ return 1024 ;
504
+ };
505
+ unsigned ds = getVal (adaptor.getDsAttr ());
506
+ unsigned exp = getVal (adaptor.getExpAttr ());
507
+
508
+ unsigned vmcnt = 1024 ;
509
+ Attribute load = adaptor.getLoadAttr ();
510
+ Attribute store = adaptor.getStoreAttr ();
511
+ if (load && store) {
512
+ vmcnt = getVal (load) + getVal (store);
513
+ } else if (load) {
514
+ vmcnt = getVal (load);
515
+ } else if (store) {
516
+ vmcnt = getVal (store);
517
+ }
518
+
519
+ FailureOr<unsigned > waitcnt = encodeWaitcnt (chipset, vmcnt, exp, ds);
520
+ if (failed (waitcnt))
521
+ return op.emitOpError (" unsupported chipset" );
522
+
523
+ rewriter.replaceOpWithNewOp <ROCDL::SWaitcntOp>(op, *waitcnt);
524
+ return success ();
525
+ }
526
+ };
527
+
422
528
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern <LDSBarrierOp> {
423
529
LDSBarrierOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
424
530
: ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1931,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
1825
1931
ROCDL::RawPtrBufferAtomicUminOp>,
1826
1932
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1827
1933
ROCDL::RawPtrBufferAtomicCmpSwap>,
1828
- AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering ,
1829
- MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering ,
1830
- ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1934
+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering ,
1935
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering ,
1936
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1831
1937
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1832
1938
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1833
1939
TransposeLoadOpLowering>(converter, chipset);
0 commit comments