Skip to content

Commit 909c9aa

Browse files
authored
[mlir][amdgpu] Add lowerings for ScaledExtPacked816 (#168123)
* Adds lowerings for amdgpy.scaled_ext_packed816 * updates verifiers
1 parent 3f60d22 commit 909c9aa

File tree

5 files changed

+379
-54
lines changed

5 files changed

+379
-54
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 185 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
4343
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
4444
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
4545
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
46+
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
4647

4748
/// Convert an unsigned number `val` to i32.
4849
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
11491150
k, isRDNA3);
11501151

11511152
// Handle gfx1250.
1152-
if (chipset == Chipset{12, 5, 0})
1153+
if (chipset == kGfx1250)
11531154
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
11541155
elemDestType, k);
11551156

@@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13001301
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
13011302
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
13021303

1303-
bool isGFX1250 = chipset >= Chipset(12, 5, 0);
1304+
bool isGFX1250 = chipset >= kGfx1250;
13041305

13051306
// The WMMA operations represent vectors of bf16s as vectors of i16s
13061307
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
@@ -1505,6 +1506,19 @@ struct ExtPackedFp8OpLowering final
15051506
ConversionPatternRewriter &rewriter) const override;
15061507
};
15071508

1509+
struct ScaledExtPacked816OpLowering final
1510+
: public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
1511+
ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
1512+
Chipset chipset)
1513+
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
1514+
chipset(chipset) {}
1515+
Chipset chipset;
1516+
1517+
LogicalResult
1518+
matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1519+
ConversionPatternRewriter &rewriter) const override;
1520+
};
1521+
15081522
struct PackedTrunc2xFp8OpLowering final
15091523
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
15101524
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1613,6 +1627,170 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
16131627
return success();
16141628
}
16151629

1630+
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
1631+
int32_t firstScaleLane, int32_t firstScaleByte) {
1632+
// When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
1633+
// operations, the attributes blockSize, sourceType, firstScaleLane and
1634+
// firstScaleByte are merged into a single attribute scaleSel. This is how
1635+
// those values are merged together.
1636+
assert(llvm::is_contained({16, 32}, blockSize));
1637+
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
1638+
1639+
const bool is_fp8 = bitWidth == 8;
1640+
const bool is_block_16 = blockSize == 16;
1641+
1642+
if (!is_fp8) {
1643+
int bit_0 = is_block_16;
1644+
assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1645+
int bit_1 = (firstScaleByte == 2) << 1;
1646+
assert(llvm::is_contained({0, 1}, firstScaleLane));
1647+
int bit_2 = firstScaleLane << 2;
1648+
return bit_2 | bit_1 | bit_0;
1649+
}
1650+
1651+
int bit_0 = is_block_16;
1652+
// firstScaleByte is guaranteed to be defined by two bits.
1653+
assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1654+
int bit_2_and_1 = firstScaleByte << 1;
1655+
assert(llvm::is_contained({0, 1}, firstScaleLane));
1656+
int bit_3 = firstScaleLane << 3;
1657+
int bits = bit_3 | bit_2_and_1 | bit_0;
1658+
// These are invalid cases.
1659+
assert(!llvm::is_contained(
1660+
{0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1661+
return bits;
1662+
}
1663+
1664+
static std::optional<StringRef>
1665+
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
1666+
using fp4 = Float4E2M1FNType;
1667+
using fp8 = Float8E4M3FNType;
1668+
using bf8 = Float8E5M2Type;
1669+
using fp6 = Float6E2M3FNType;
1670+
using bf6 = Float6E3M2FNType;
1671+
if (isa<fp4>(srcElemType)) {
1672+
if (destElemType.isF16())
1673+
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1674+
if (destElemType.isBF16())
1675+
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1676+
if (destElemType.isF32())
1677+
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1678+
return std::nullopt;
1679+
}
1680+
if (isa<fp8>(srcElemType)) {
1681+
if (destElemType.isF16())
1682+
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1683+
if (destElemType.isBF16())
1684+
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1685+
if (destElemType.isF32())
1686+
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1687+
return std::nullopt;
1688+
}
1689+
if (isa<bf8>(srcElemType)) {
1690+
if (destElemType.isF16())
1691+
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1692+
if (destElemType.isBF16())
1693+
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1694+
if (destElemType.isF32())
1695+
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1696+
return std::nullopt;
1697+
}
1698+
if (isa<fp6>(srcElemType)) {
1699+
if (destElemType.isF16())
1700+
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1701+
if (destElemType.isBF16())
1702+
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1703+
if (destElemType.isF32())
1704+
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1705+
return std::nullopt;
1706+
}
1707+
if (isa<bf6>(srcElemType)) {
1708+
if (destElemType.isF16())
1709+
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1710+
if (destElemType.isBF16())
1711+
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1712+
if (destElemType.isF32())
1713+
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1714+
return std::nullopt;
1715+
}
1716+
llvm_unreachable("invalid combination of element types for packed conversion "
1717+
"instructions");
1718+
}
1719+
1720+
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
1721+
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1722+
ConversionPatternRewriter &rewriter) const {
1723+
using fp4 = Float4E2M1FNType;
1724+
using fp8 = Float8E4M3FNType;
1725+
using bf8 = Float8E5M2Type;
1726+
using fp6 = Float6E2M3FNType;
1727+
using bf6 = Float6E3M2FNType;
1728+
Location loc = op.getLoc();
1729+
if (chipset != kGfx1250) {
1730+
return rewriter.notifyMatchFailure(
1731+
loc,
1732+
"Scaled fp packed conversion instructions are not available on target "
1733+
"architecture and their emulation is not implemented");
1734+
}
1735+
int32_t firstScaleLane = op.getFirstScaleLane();
1736+
int32_t firstScaleByte = op.getFirstScaleByte();
1737+
int32_t blockSize = op.getBlockSize();
1738+
auto sourceType = cast<VectorType>(op.getSource().getType());
1739+
auto srcElemType = cast<FloatType>(sourceType.getElementType());
1740+
unsigned bitWidth = srcElemType.getWidth();
1741+
int32_t scaleSel =
1742+
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
1743+
1744+
auto targetType = cast<VectorType>(op.getResult().getType());
1745+
auto destElemType = cast<FloatType>(targetType.getElementType());
1746+
IntegerType i32 = rewriter.getI32Type();
1747+
Value castedScale =
1748+
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
1749+
1750+
Value source = adaptor.getSource();
1751+
Type llvmResultType = typeConverter->convertType(op.getResult().getType());
1752+
Type packedType = nullptr;
1753+
if (isa<fp4>(srcElemType)) {
1754+
packedType = i32;
1755+
packedType = getTypeConverter()->convertType(packedType);
1756+
} else if (isa<fp8, bf8>(srcElemType)) {
1757+
packedType = VectorType::get(2, i32);
1758+
packedType = getTypeConverter()->convertType(packedType);
1759+
} else if (isa<fp6, bf6>(srcElemType)) {
1760+
packedType = VectorType::get(3, i32);
1761+
packedType = getTypeConverter()->convertType(packedType);
1762+
} else {
1763+
llvm_unreachable("invalid element type for packed scaled ext");
1764+
}
1765+
1766+
if (!packedType || !llvmResultType) {
1767+
return rewriter.notifyMatchFailure(op, "type conversion failed");
1768+
}
1769+
1770+
Value castedSource =
1771+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
1772+
1773+
std::optional<StringRef> maybeIntrinsic =
1774+
scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
1775+
if (!maybeIntrinsic.has_value())
1776+
return op.emitOpError(
1777+
"no intrinsic matching packed scaled conversion on the given chipset");
1778+
1779+
OperationState loweredOp(loc, *maybeIntrinsic);
1780+
loweredOp.addTypes({llvmResultType});
1781+
loweredOp.addOperands({castedSource, castedScale});
1782+
1783+
SmallVector<NamedAttribute, 1> attrs;
1784+
attrs.push_back(
1785+
NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
1786+
1787+
loweredOp.addAttributes(attrs);
1788+
Operation *lowered = rewriter.create(loweredOp);
1789+
rewriter.replaceOp(op, lowered);
1790+
1791+
return success();
1792+
}
1793+
16161794
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
16171795
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
16181796
ConversionPatternRewriter &rewriter) const {
@@ -2151,9 +2329,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
21512329
ROCDL::RawPtrBufferAtomicCmpSwap>,
21522330
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
21532331
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2154-
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2155-
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2156-
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2157-
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2332+
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
2333+
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2334+
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2335+
GatherToLDSOpLowering, TransposeLoadOpLowering,
2336+
AMDGPUPermlaneLowering>(converter, chipset);
21582337
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
21592338
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
343343
//===----------------------------------------------------------------------===//
344344
LogicalResult ScaledExtPacked816Op::verify() {
345345
int blockSize = getBlockSize();
346-
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
346+
assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
347347

348348
int firstScaleByte = getFirstScaleByte();
349+
int firstScaleLane = getFirstScaleLane();
349350
auto sourceType = cast<VectorType>(getSource().getType());
350351
Type elementType = sourceType.getElementType();
351352
auto floatType = cast<FloatType>(elementType);
352-
int bitWidth = floatType.getWidth();
353+
unsigned bitWidth = floatType.getWidth();
353354

354-
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
355-
!llvm::is_contained({0, 1}, firstScaleByte)) {
356-
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
357-
"for f4 and f6.");
358-
}
359-
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
360-
!llvm::is_contained({0, 2}, firstScaleByte)) {
361-
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
362-
"for f4 and f6.");
363-
}
364-
if (bitWidth == 8 && blockSize == 16 &&
365-
!llvm::is_contained({0, 2}, firstScaleByte)) {
366-
return emitOpError(
367-
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
355+
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
356+
357+
const bool is_fp8 = bitWidth == 8;
358+
const bool is_block_16 = blockSize == 16;
359+
360+
if (!is_fp8) {
361+
if (is_block_16) {
362+
if (!llvm::is_contained({0, 1}, firstScaleByte)) {
363+
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
364+
"or 1 for f4 and f6.");
365+
}
366+
} else {
367+
if (!llvm::is_contained({0, 2}, firstScaleByte)) {
368+
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
369+
"or 2 for f4 and f6.");
370+
}
371+
}
372+
} else {
373+
if (is_block_16) {
374+
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
375+
((firstScaleLane == 1) && (firstScaleByte == 2));
376+
if (!is_valid) {
377+
return emitOpError("blockSize of 16 can only have (firstScaleLane, "
378+
"firstScaleByte) be (0, 0) or (1, 2) for f8.");
379+
}
380+
}
368381
}
369382

370383
return success();

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,4 @@ func.func @sched_barrier() {
456456
amdgpu.sched_barrier allow = <valu|all_vmem>
457457
func.return
458458
}
459+

0 commit comments

Comments
 (0)