Skip to content

Commit a7a853e

Browse files
committed
Refactor NFC
1 parent 0f8f3c4 commit a7a853e

File tree

1 file changed

+40
-52
lines changed

1 file changed

+40
-52
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 40 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -1708,94 +1708,88 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17081708
Value castedScale =
17091709
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
17101710

1711+
Value source = adaptor.getSource();
1712+
Type packedType;
1713+
if (isa<Float4E2M1FNType>(srcElemType)) {
1714+
packedType = i32;
1715+
} else if (isa<Float8E4M3FNType>(srcElemType) ||
1716+
isa<Float8E5M2Type>(srcElemType)) {
1717+
packedType = VectorType::get(2, i32);
1718+
} else if (isa<Float6E2M3FNType>(srcElemType) ||
1719+
isa<Float6E3M2FNType>(srcElemType)) {
1720+
packedType = VectorType::get(3, i32);
1721+
} else {
1722+
llvm_unreachable("invalid element type for scaled ext");
1723+
}
17111724
// Ok, so we need to construct ops depending on the sourceType and targetType.
17121725
// smallT = [Fp4, Fp8, Bf8]
17131726
// Bf8 = E5M2
17141727
// Fp8 = E4M3
17151728
//
17161729
// largeT = [F16, Bf16, F32]
17171730
// CvtPkScalePk8${largeT}${smallT}
1718-
Value source = adaptor.getSource();
17191731

17201732
if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16()) {
17211733
// CvtPkScalePk8F16Fp4Op
17221734
// i32
17231735
Value castedSource =
1724-
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
1736+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17251737
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
17261738
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17271739
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16()) {
17281740
// CvtPkScalePk8F16Fp8Op
17291741
// vector<8xf8E4M3FN>
1730-
1731-
// vector<2xi32>
1732-
VectorType v2xi32 = VectorType::get(2, i32);
1733-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1734-
1742+
Value castedSource =
1743+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17351744
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
17361745
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17371746
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16()) {
17381747
// CvtPkScalePk8F16Bf8Op
17391748
// vector<8xf8E5M2>
1740-
1741-
// vector<2xi32>
1742-
VectorType v2xi32 = VectorType::get(2, i32);
1743-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1744-
1749+
Value castedSource =
1750+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17451751
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
17461752
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17471753
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16()) {
17481754
// CvtPkScalePk8Bf16Fp4Op
17491755
// i32
17501756
Value castedSource =
1751-
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
1757+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17521758
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
17531759
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17541760
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16()) {
17551761
// CvtPkScalePk8Bf16Fp8Op
17561762
// vector<8xf8E4M3FN>
1757-
1758-
// vector<2xi32>
1759-
VectorType v2xi32 = VectorType::get(2, i32);
1760-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1761-
1763+
Value castedSource =
1764+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17621765
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
17631766
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17641767
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16()) {
17651768
// CvtPkScalePk8Bf16Bf8Op
17661769
// vector<8xf8E5M2>
1767-
1768-
// vector<2xi32>
1769-
VectorType v2xi32 = VectorType::get(2, i32);
1770-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1771-
1770+
Value castedSource =
1771+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17721772
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
17731773
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17741774
} else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32()) {
17751775
// CvtPkScalePk8F32Fp4Op
17761776
// i32
17771777
Value castedSource =
1778-
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource());
1778+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17791779
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
17801780
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17811781
} else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32()) {
17821782
// CvtPkScalePk8F32Fp8Op
17831783
// vector<8xf8E4M3FN>
1784-
1785-
// vector<2xi32>
1786-
VectorType v2xi32 = VectorType::get(2, i32);
1787-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1788-
1784+
Value castedSource =
1785+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17891786
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
17901787
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
17911788
} else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32()) {
17921789
// CvtPkScalePk8F32Bf8Op
17931790
// vector<8xf8E5M2>
1794-
1795-
// vector<2xi32>
1796-
VectorType v2xi32 = VectorType::get(2, i32);
1797-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source);
1798-
1791+
Value castedSource =
1792+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
17991793
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
18001794
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18011795
}
@@ -1808,44 +1802,38 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
18081802
// CvtPkScalePk16${largeT}${smallT}
18091803
else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16()) {
18101804
// CvtPkScalePk16F16Fp6Op
1811-
VectorType v3xi32 = VectorType::get(3, i32);
1812-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1813-
1805+
Value castedSource =
1806+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18141807
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
18151808
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18161809
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16()) {
18171810
// CvtPkScalePk16F16Bf6Op
1818-
VectorType v3xi32 = VectorType::get(3, i32);
1819-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1820-
1811+
Value castedSource =
1812+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18211813
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
18221814
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18231815
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16()) {
18241816
// CvtPkScalePk16Bf16Fp6Op
1825-
VectorType v3xi32 = VectorType::get(3, i32);
1826-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1827-
1817+
Value castedSource =
1818+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18281819
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
18291820
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18301821
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16()) {
18311822
// CvtPkScalePk16Bf16Bf6Op
1832-
VectorType v3xi32 = VectorType::get(3, i32);
1833-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1834-
1823+
Value castedSource =
1824+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18351825
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
18361826
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18371827
} else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32()) {
18381828
// CvtPkScalePk16F32Fp6Op
1839-
VectorType v3xi32 = VectorType::get(3, i32);
1840-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1841-
1829+
Value castedSource =
1830+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18421831
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
18431832
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18441833
} else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32()) {
18451834
// CvtPkScalePk16F32Bf6Op
1846-
VectorType v3xi32 = VectorType::get(3, i32);
1847-
Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source);
1848-
1835+
Value castedSource =
1836+
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
18491837
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
18501838
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
18511839
} else {

0 commit comments

Comments
 (0)