@@ -1728,68 +1728,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17281728 //
17291729 // largeT = [F16, Bf16, F32]
17301730 // CvtPkScalePk8${largeT}${smallT}
1731+ Value castedSource =
1732+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17311733
17321734 if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16 ()) {
1733- // CvtPkScalePk8F16Fp4Op
1734- // i32
1735- Value castedSource =
1736- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17371735 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp4Op>(
17381736 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17391737 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16 ()) {
1740- // CvtPkScalePk8F16Fp8Op
1741- // vector<8xf8E4M3FN>
1742- Value castedSource =
1743- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17441738 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp8Op>(
17451739 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17461740 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16 ()) {
1747- // CvtPkScalePk8F16Bf8Op
1748- // vector<8xf8E5M2>
1749- Value castedSource =
1750- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17511741 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Bf8Op>(
17521742 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17531743 } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16 ()) {
1754- // CvtPkScalePk8Bf16Fp4Op
1755- // i32
1756- Value castedSource =
1757- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17581744 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp4Op>(
17591745 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17601746 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16 ()) {
1761- // CvtPkScalePk8Bf16Fp8Op
1762- // vector<8xf8E4M3FN>
1763- Value castedSource =
1764- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17651747 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp8Op>(
17661748 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17671749 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16 ()) {
1768- // CvtPkScalePk8Bf16Bf8Op
1769- // vector<8xf8E5M2>
1770- Value castedSource =
1771- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17721750 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Bf8Op>(
17731751 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17741752 } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32 ()) {
1775- // CvtPkScalePk8F32Fp4Op
1776- // i32
1777- Value castedSource =
1778- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17791753 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp4Op>(
17801754 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17811755 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32 ()) {
1782- // CvtPkScalePk8F32Fp8Op
1783- // vector<8xf8E4M3FN>
1784- Value castedSource =
1785- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17861756 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp8Op>(
17871757 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17881758 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32 ()) {
1789- // CvtPkScalePk8F32Bf8Op
1790- // vector<8xf8E5M2>
1791- Value castedSource =
1792- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17931759 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Bf8Op>(
17941760 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17951761 }
@@ -1801,39 +1767,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
18011767
18021768 // CvtPkScalePk16${largeT}${smallT}
18031769 else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16 ()) {
1804- // CvtPkScale16F16Fp6Op
1805- Value castedSource =
1806- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18071770 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Fp6Op>(
18081771 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18091772 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16 ()) {
1810- // CvtPkScale16F16Bf6Op
1811- Value castedSource =
1812- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18131773 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Bf6Op>(
18141774 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18151775 } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16 ()) {
1816- // CvtPkScale16Bf16Fp6Op
1817- Value castedSource =
1818- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18191776 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Fp6Op>(
18201777 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18211778 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16 ()) {
1822- // CvtPkScale16Bf16Bf6Op
1823- Value castedSource =
1824- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18251779 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Bf6Op>(
18261780 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18271781 } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32 ()) {
1828- // CvtPkScale16F32Fp6Op
1829- Value castedSource =
1830- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18311782 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Fp6Op>(
18321783 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18331784 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32 ()) {
1834- // CvtPkScale16F32Bf6Op
1835- Value castedSource =
1836- LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18371785 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Bf6Op>(
18381786 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18391787 } else {
0 commit comments