@@ -1506,16 +1506,17 @@ struct ExtPackedFp8OpLowering final
15061506 ConversionPatternRewriter &rewriter) const override ;
15071507};
15081508
1509- struct ScaledExtPacked816OpLowering final
1510- : public ConvertOpToLLVMPattern<ScaledExtPacked816Op > {
1511- ScaledExtPacked816OpLowering (const LLVMTypeConverter &converter,
1512- Chipset chipset)
1513- : ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op >(converter),
1509+ struct ScaledExtPackedMatrixOpLowering final
1510+ : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp > {
1511+ ScaledExtPackedMatrixOpLowering (const LLVMTypeConverter &converter,
1512+ Chipset chipset)
1513+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp >(converter),
15141514 chipset (chipset) {}
15151515 Chipset chipset;
15161516
15171517 LogicalResult
1518- matchAndRewrite (ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1518+ matchAndRewrite (ScaledExtPackedMatrixOp op,
1519+ ScaledExtPackedMatrixOpAdaptor adaptor,
15191520 ConversionPatternRewriter &rewriter) const override ;
15201521};
15211522
@@ -1627,34 +1628,35 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
16271628 return success ();
16281629}
16291630
1630- int32_t getScaleSel (int32_t blockSize, unsigned bitWidth,
1631- int32_t firstScaleLane, int32_t firstScaleByte) {
1632- // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
1633- // operations, the attributes blockSize, sourceType, firstScaleLane and
1631+ int32_t getScaleSel (int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1632+ int32_t firstScaleByte) {
1633+ // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1634+ // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
16341635 // firstScaleByte are merged into a single attribute scaleSel. This is how
1635- // those values are merged together.
1636+ // those values are merged together. (Note: scaleWaveHalf isn't a high-level
1637+ // attribute but is derifed from firstScaleLane).
16361638 assert (llvm::is_contained ({16 , 32 }, blockSize));
16371639 assert (llvm::is_contained (llvm::ArrayRef<unsigned >{4 , 6 , 8 }, bitWidth));
16381640
1639- const bool is_fp8 = bitWidth == 8 ;
1640- const bool is_block_16 = blockSize == 16 ;
1641+ const bool isFp8 = bitWidth == 8 ;
1642+ const bool isBlock16 = blockSize == 16 ;
16411643
1642- if (!is_fp8 ) {
1643- int bit_0 = is_block_16 ;
1644+ if (!isFp8 ) {
1645+ int32_t bit0 = isBlock16 ;
16441646 assert (llvm::is_contained ({0 , 1 , 2 }, firstScaleByte));
1645- int bit_1 = (firstScaleByte == 2 ) << 1 ;
1646- assert (llvm::is_contained ({0 , 1 }, firstScaleLane ));
1647- int bit_2 = firstScaleLane << 2 ;
1648- return bit_2 | bit_1 | bit_0 ;
1647+ int32_t bit1 = (firstScaleByte == 2 ) << 1 ;
1648+ assert (llvm::is_contained ({0 , 1 }, scaleWaveHalf ));
1649+ int32_t bit2 = scaleWaveHalf << 2 ;
1650+ return bit2 | bit1 | bit0 ;
16491651 }
16501652
1651- int bit_0 = is_block_16 ;
1653+ int32_t bit0 = isBlock16 ;
16521654 // firstScaleByte is guaranteed to be defined by two bits.
16531655 assert (llvm::is_contained ({0 , 1 , 2 , 3 }, firstScaleByte));
1654- int bit_2_and_1 = firstScaleByte << 1 ;
1655- assert (llvm::is_contained ({0 , 1 }, firstScaleLane ));
1656- int bit_3 = firstScaleLane << 3 ;
1657- int bits = bit_3 | bit_2_and_1 | bit_0 ;
1656+ int32_t bits2and1 = firstScaleByte << 1 ;
1657+ assert (llvm::is_contained ({0 , 1 }, scaleWaveHalf ));
1658+ int32_t bit3 = scaleWaveHalf << 3 ;
1659+ int32_t bits = bit3 | bits2and1 | bit0 ;
16581660 // These are invalid cases.
16591661 assert (!llvm::is_contained (
16601662 {0b0011 , 0b0101 , 0b0111 , 0b1000 , 0b1001 , 0b1011 , 0b1111 }, bits));
@@ -1717,8 +1719,8 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
17171719 " instructions" );
17181720}
17191721
1720- LogicalResult ScaledExtPacked816OpLowering ::matchAndRewrite (
1721- ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1722+ LogicalResult ScaledExtPackedMatrixOpLowering ::matchAndRewrite (
1723+ ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
17221724 ConversionPatternRewriter &rewriter) const {
17231725 using fp4 = Float4E2M1FNType;
17241726 using fp8 = Float8E4M3FNType;
@@ -1732,7 +1734,9 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17321734 " Scaled fp packed conversion instructions are not available on target "
17331735 " architecture and their emulation is not implemented" );
17341736 }
1735- int32_t firstScaleLane = op.getFirstScaleLane ();
1737+ // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
1738+ // is being selected.
1739+ int32_t scaleWaveHalf = op.getFirstScaleLane () / 16 ;
17361740 int32_t firstScaleByte = op.getFirstScaleByte ();
17371741 int32_t blockSize = op.getBlockSize ();
17381742 auto sourceType = cast<VectorType>(op.getSource ().getType ());
@@ -1770,7 +1774,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17701774 " no intrinsic matching packed scaled conversion on the given chipset" );
17711775
17721776 int32_t scaleSel =
1773- getScaleSel (blockSize, bitWidth, firstScaleLane , firstScaleByte);
1777+ getScaleSel (blockSize, bitWidth, scaleWaveHalf , firstScaleByte);
17741778 Value castedScale =
17751779 LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor.getScale ());
17761780 Value castedSource =
@@ -2388,27 +2392,26 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
23882392 RewritePatternSet &patterns,
23892393 Chipset chipset) {
23902394 populateAMDGPUMemorySpaceAttributeConversions (converter);
2391- patterns
2392- .add <FatRawBufferCastLowering,
2393- RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2394- RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2395- RawBufferOpLowering<RawBufferAtomicFaddOp,
2396- ROCDL::RawPtrBufferAtomicFaddOp>,
2397- RawBufferOpLowering<RawBufferAtomicFmaxOp,
2398- ROCDL::RawPtrBufferAtomicFmaxOp>,
2399- RawBufferOpLowering<RawBufferAtomicSmaxOp,
2400- ROCDL::RawPtrBufferAtomicSmaxOp>,
2401- RawBufferOpLowering<RawBufferAtomicUminOp,
2402- ROCDL::RawPtrBufferAtomicUminOp>,
2403- RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2404- ROCDL::RawPtrBufferAtomicCmpSwap>,
2405- AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2406- SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2407- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
2408- ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2409- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2410- GatherToLDSOpLowering, TransposeLoadOpLowering,
2411- AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering>(converter,
2412- chipset);
2395+ patterns.add <
2396+ FatRawBufferCastLowering,
2397+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2398+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2399+ RawBufferOpLowering<RawBufferAtomicFaddOp,
2400+ ROCDL::RawPtrBufferAtomicFaddOp>,
2401+ RawBufferOpLowering<RawBufferAtomicFmaxOp,
2402+ ROCDL::RawPtrBufferAtomicFmaxOp>,
2403+ RawBufferOpLowering<RawBufferAtomicSmaxOp,
2404+ ROCDL::RawPtrBufferAtomicSmaxOp>,
2405+ RawBufferOpLowering<RawBufferAtomicUminOp,
2406+ ROCDL::RawPtrBufferAtomicUminOp>,
2407+ RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2408+ ROCDL::RawPtrBufferAtomicCmpSwap>,
2409+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2410+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2411+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2412+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2413+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2414+ GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2415+ AMDGPUMakeDmaBaseLowering>(converter, chipset);
24132416 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
24142417}
0 commit comments