Skip to content

Commit 33ef57e

Browse files
committed
Using
1 parent ed66571 commit 33ef57e

File tree

1 file changed

+24
-34
lines changed

1 file changed

+24
-34
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,7 +1620,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16201620
// firstScaleByte are merged into a single attribute scaleSel. This is how
16211621
// those values are merged together.
16221622
assert(llvm::is_contained({16, 32}, blockSize));
1623-
assert(llvm::is_contained({4, 6, 8}, bitWidth));
1623+
assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
16241624

16251625
const bool is_fp8 = bitWidth == 8;
16261626
const bool is_block_16 = blockSize == 16;
@@ -1653,6 +1653,11 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16531653
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16541654
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
16551655
ConversionPatternRewriter &rewriter) const {
1656+
using fp4 = Float4E2M1FNType;
1657+
using fp8 = Float8E4M3FNType;
1658+
using bf8 = Float8E5M2Type;
1659+
using fp6 = Float6E2M3FNType;
1660+
using bf6 = Float6E3M2FNType;
16561661
int32_t firstScaleLane = op.getFirstScaleLane();
16571662
int32_t firstScaleByte = op.getFirstScaleByte();
16581663
int32_t blockSize = op.getBlockSize();
@@ -1671,79 +1676,64 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16711676

16721677
Value source = adaptor.getSource();
16731678
Type packedType;
1674-
if (isa<Float4E2M1FNType>(srcElemType)) {
1679+
if (isa<fp4>(srcElemType)) {
16751680
packedType = i32;
16761681
packedType = getTypeConverter()->convertType(packedType);
1677-
} else if (isa<Float8E4M3FNType>(srcElemType) ||
1678-
isa<Float8E5M2Type>(srcElemType)) {
1682+
} else if (isa<fp8, bf8>(srcElemType)) {
16791683
packedType = VectorType::get(2, i32);
16801684
packedType = getTypeConverter()->convertType(packedType);
1681-
} else if (isa<Float6E2M3FNType>(srcElemType) ||
1682-
isa<Float6E3M2FNType>(srcElemType)) {
1685+
} else if (isa<fp6, bf6>(srcElemType)) {
16831686
packedType = VectorType::get(3, i32);
16841687
packedType = getTypeConverter()->convertType(packedType);
16851688
} else {
16861689
llvm_unreachable("invalid element type for scaled ext");
16871690
}
1688-
// smallT = [Fp4, Fp8, Bf8]
1689-
// Bf8 = E5M2
1690-
// Fp8 = E4M3
1691-
//
1692-
// largeT = [F16, Bf16, F32]
1693-
// CvtPkScalePk8${largeT}${smallT}
16941691
Value castedSource =
16951692
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
16961693

1697-
if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
1694+
if (isa<fp4>(srcElemType) && destElemType.isF16()) {
16981695
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
16991696
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1700-
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
1697+
} else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
17011698
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
17021699
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1703-
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
1700+
} else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
17041701
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
17051702
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1706-
} else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
1703+
} else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
17071704
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
17081705
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1709-
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
1706+
} else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
17101707
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
17111708
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1712-
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
1709+
} else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
17131710
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
17141711
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1715-
} else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
1712+
} else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
17161713
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
17171714
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1718-
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
1715+
} else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
17191716
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
17201717
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1721-
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
1718+
} else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
17221719
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
17231720
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1724-
}
1725-
// smallT = [Fp6, Bf6]
1726-
// Fp6 = Float6E2M3FN
1727-
// Bf6 = Float6E3M2FN
1728-
// largeT = [F16, Bf16, F32]
1729-
//
1730-
// CvtPkScalePk16${largeT}${smallT}
1731-
else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
1721+
} else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
17321722
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
17331723
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1734-
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
1724+
} else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
17351725
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
17361726
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1737-
} else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
1727+
} else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
17381728
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
17391729
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1740-
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
1730+
} else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
17411731
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
17421732
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1743-
} else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
1733+
} else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
17441734
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
17451735
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1746-
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
1736+
} else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
17471737
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
17481738
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17491739
} else {

0 commit comments

Comments
 (0)