@@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
4343constexpr Chipset kGfx90a = Chipset(9 , 0 , 0xa );
4444constexpr Chipset kGfx942 = Chipset(9 , 4 , 2 );
4545constexpr Chipset kGfx950 = Chipset(9 , 5 , 0 );
46+ constexpr Chipset kGfx1250 = Chipset(12 , 5 , 0 );
4647
4748// / Convert an unsigned number `val` to i32.
4849static Value convertUnsignedToI32 (ConversionPatternRewriter &rewriter,
@@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
11491150 k, isRDNA3);
11501151
11511152 // Handle gfx1250.
1152- if (chipset == Chipset{ 12 , 5 , 0 } )
1153+ if (chipset == kGfx1250 )
11531154 return wmmaOpToIntrinsicGfx1250 (elemSourceType, elemBSourceType,
11541155 elemDestType, k);
11551156
@@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13001301 if (chipset.majorVersion != 11 && chipset.majorVersion != 12 )
13011302 return op->emitOpError (" WMMA only supported on gfx11 and gfx12" );
13021303
1303- bool isGFX1250 = chipset >= Chipset ( 12 , 5 , 0 ) ;
1304+ bool isGFX1250 = chipset >= kGfx1250 ;
13041305
13051306 // The WMMA operations represent vectors of bf16s as vectors of i16s
13061307 // (except on gfx1250), so we need to bitcast bfloats to i16 and then
@@ -1505,6 +1506,19 @@ struct ExtPackedFp8OpLowering final
15051506 ConversionPatternRewriter &rewriter) const override ;
15061507};
15071508
1509+ struct ScaledExtPacked816OpLowering final
1510+ : public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
1511+ ScaledExtPacked816OpLowering (const LLVMTypeConverter &converter,
1512+ Chipset chipset)
1513+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
1514+ chipset (chipset) {}
1515+ Chipset chipset;
1516+
1517+ LogicalResult
1518+ matchAndRewrite (ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1519+ ConversionPatternRewriter &rewriter) const override ;
1520+ };
1521+
15081522struct PackedTrunc2xFp8OpLowering final
15091523 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
15101524 PackedTrunc2xFp8OpLowering (const LLVMTypeConverter &converter,
@@ -1613,6 +1627,170 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
16131627 return success ();
16141628}
16151629
1630+ int32_t getScaleSel (int32_t blockSize, unsigned bitWidth,
1631+ int32_t firstScaleLane, int32_t firstScaleByte) {
1632+ // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
1633+ // operations, the attributes blockSize, sourceType, firstScaleLane and
1634+ // firstScaleByte are merged into a single attribute scaleSel. This is how
1635+ // those values are merged together.
1636+ assert (llvm::is_contained ({16 , 32 }, blockSize));
1637+ assert (llvm::is_contained (llvm::ArrayRef<unsigned >{4 , 6 , 8 }, bitWidth));
1638+
1639+ const bool is_fp8 = bitWidth == 8 ;
1640+ const bool is_block_16 = blockSize == 16 ;
1641+
1642+ if (!is_fp8) {
1643+ int bit_0 = is_block_16;
1644+ assert (llvm::is_contained ({0 , 1 , 2 }, firstScaleByte));
1645+ int bit_1 = (firstScaleByte == 2 ) << 1 ;
1646+ assert (llvm::is_contained ({0 , 1 }, firstScaleLane));
1647+ int bit_2 = firstScaleLane << 2 ;
1648+ return bit_2 | bit_1 | bit_0;
1649+ }
1650+
1651+ int bit_0 = is_block_16;
1652+ // firstScaleByte is guaranteed to be defined by two bits.
1653+ assert (llvm::is_contained ({0 , 1 , 2 , 3 }, firstScaleByte));
1654+ int bit_2_and_1 = firstScaleByte << 1 ;
1655+ assert (llvm::is_contained ({0 , 1 }, firstScaleLane));
1656+ int bit_3 = firstScaleLane << 3 ;
1657+ int bits = bit_3 | bit_2_and_1 | bit_0;
1658+ // These are invalid cases.
1659+ assert (!llvm::is_contained (
1660+ {0b0011 , 0b0101 , 0b0111 , 0b1000 , 0b1001 , 0b1011 , 0b1111 }, bits));
1661+ return bits;
1662+ }
1663+
1664+ static std::optional<StringRef>
1665+ scaledExtPacked816ToIntrinsic (Type srcElemType, Type destElemType) {
1666+ using fp4 = Float4E2M1FNType;
1667+ using fp8 = Float8E4M3FNType;
1668+ using bf8 = Float8E5M2Type;
1669+ using fp6 = Float6E2M3FNType;
1670+ using bf6 = Float6E3M2FNType;
1671+ if (isa<fp4>(srcElemType)) {
1672+ if (destElemType.isF16 ())
1673+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName ();
1674+ if (destElemType.isBF16 ())
1675+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName ();
1676+ if (destElemType.isF32 ())
1677+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName ();
1678+ return std::nullopt ;
1679+ }
1680+ if (isa<fp8>(srcElemType)) {
1681+ if (destElemType.isF16 ())
1682+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName ();
1683+ if (destElemType.isBF16 ())
1684+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName ();
1685+ if (destElemType.isF32 ())
1686+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName ();
1687+ return std::nullopt ;
1688+ }
1689+ if (isa<bf8>(srcElemType)) {
1690+ if (destElemType.isF16 ())
1691+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName ();
1692+ if (destElemType.isBF16 ())
1693+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName ();
1694+ if (destElemType.isF32 ())
1695+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName ();
1696+ return std::nullopt ;
1697+ }
1698+ if (isa<fp6>(srcElemType)) {
1699+ if (destElemType.isF16 ())
1700+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName ();
1701+ if (destElemType.isBF16 ())
1702+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName ();
1703+ if (destElemType.isF32 ())
1704+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName ();
1705+ return std::nullopt ;
1706+ }
1707+ if (isa<bf6>(srcElemType)) {
1708+ if (destElemType.isF16 ())
1709+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName ();
1710+ if (destElemType.isBF16 ())
1711+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName ();
1712+ if (destElemType.isF32 ())
1713+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName ();
1714+ return std::nullopt ;
1715+ }
1716+ llvm_unreachable (" invalid combination of element types for packed conversion "
1717+ " instructions" );
1718+ }
1719+
1720+ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite (
1721+ ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1722+ ConversionPatternRewriter &rewriter) const {
1723+ using fp4 = Float4E2M1FNType;
1724+ using fp8 = Float8E4M3FNType;
1725+ using bf8 = Float8E5M2Type;
1726+ using fp6 = Float6E2M3FNType;
1727+ using bf6 = Float6E3M2FNType;
1728+ Location loc = op.getLoc ();
1729+ if (chipset != kGfx1250 ) {
1730+ return rewriter.notifyMatchFailure (
1731+ loc,
1732+ " Scaled fp packed conversion instructions are not available on target "
1733+ " architecture and their emulation is not implemented" );
1734+ }
1735+ int32_t firstScaleLane = op.getFirstScaleLane ();
1736+ int32_t firstScaleByte = op.getFirstScaleByte ();
1737+ int32_t blockSize = op.getBlockSize ();
1738+ auto sourceType = cast<VectorType>(op.getSource ().getType ());
1739+ auto srcElemType = cast<FloatType>(sourceType.getElementType ());
1740+ unsigned bitWidth = srcElemType.getWidth ();
1741+ int32_t scaleSel =
1742+ getScaleSel (blockSize, bitWidth, firstScaleLane, firstScaleByte);
1743+
1744+ auto targetType = cast<VectorType>(op.getResult ().getType ());
1745+ auto destElemType = cast<FloatType>(targetType.getElementType ());
1746+ IntegerType i32 = rewriter.getI32Type ();
1747+ Value castedScale =
1748+ LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor.getScale ());
1749+
1750+ Value source = adaptor.getSource ();
1751+ Type llvmResultType = typeConverter->convertType (op.getResult ().getType ());
1752+ Type packedType = nullptr ;
1753+ if (isa<fp4>(srcElemType)) {
1754+ packedType = i32 ;
1755+ packedType = getTypeConverter ()->convertType (packedType);
1756+ } else if (isa<fp8, bf8>(srcElemType)) {
1757+ packedType = VectorType::get (2 , i32 );
1758+ packedType = getTypeConverter ()->convertType (packedType);
1759+ } else if (isa<fp6, bf6>(srcElemType)) {
1760+ packedType = VectorType::get (3 , i32 );
1761+ packedType = getTypeConverter ()->convertType (packedType);
1762+ } else {
1763+ llvm_unreachable (" invalid element type for packed scaled ext" );
1764+ }
1765+
1766+ if (!packedType || !llvmResultType) {
1767+ return rewriter.notifyMatchFailure (op, " type conversion failed" );
1768+ }
1769+
1770+ Value castedSource =
1771+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
1772+
1773+ std::optional<StringRef> maybeIntrinsic =
1774+ scaledExtPacked816ToIntrinsic (srcElemType, destElemType);
1775+ if (!maybeIntrinsic.has_value ())
1776+ return op.emitOpError (
1777+ " no intrinsic matching packed scaled conversion on the given chipset" );
1778+
1779+ OperationState loweredOp (loc, *maybeIntrinsic);
1780+ loweredOp.addTypes ({llvmResultType});
1781+ loweredOp.addOperands ({castedSource, castedScale});
1782+
1783+ SmallVector<NamedAttribute, 1 > attrs;
1784+ attrs.push_back (
1785+ NamedAttribute (" scaleSel" , rewriter.getI32IntegerAttr (scaleSel)));
1786+
1787+ loweredOp.addAttributes (attrs);
1788+ Operation *lowered = rewriter.create (loweredOp);
1789+ rewriter.replaceOp (op, lowered);
1790+
1791+ return success ();
1792+ }
1793+
16161794LogicalResult ScaledExtPackedOpLowering::matchAndRewrite (
16171795 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
16181796 ConversionPatternRewriter &rewriter) const {
@@ -2151,9 +2329,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
21512329 ROCDL::RawPtrBufferAtomicCmpSwap>,
21522330 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
21532331 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2154- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2155- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2156- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2157- TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2332+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
2333+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2334+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2335+ GatherToLDSOpLowering, TransposeLoadOpLowering,
2336+ AMDGPUPermlaneLowering>(converter, chipset);
21582337 patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
21592338}
0 commit comments