@@ -1708,94 +1708,88 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17081708 Value castedScale =
17091709 LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor.getScale ());
17101710
1711+ Value source = adaptor.getSource ();
1712+ Type packedType;
1713+ if (isa<Float4E2M1FNType>(srcElemType)) {
1714+ packedType = i32 ;
1715+ } else if (isa<Float8E4M3FNType>(srcElemType) ||
1716+ isa<Float8E5M2Type>(srcElemType)) {
1717+ packedType = VectorType::get (2 , i32 );
1718+ } else if (isa<Float6E2M3FNType>(srcElemType) ||
1719+ isa<Float6E3M2FNType>(srcElemType)) {
1720+ packedType = VectorType::get (3 , i32 );
1721+ } else {
1722+ llvm_unreachable (" invalid element type for scaled ext" );
1723+ }
17111724 // Ok, so we need to construct ops depending on the sourceType and targetType.
17121725 // smallT = [Fp4, Fp8, Bf8]
17131726 // Bf8 = E5M2
17141727 // Fp8 = E4M3
17151728 //
17161729 // largeT = [F16, Bf16, F32]
17171730 // CvtPkScalePk8${largeT}${smallT}
1718- Value source = adaptor.getSource ();
17191731
17201732 if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF16 ()) {
17211733 // CvtPkScalePk8F16Fp4Op
17221734 // i32
17231735 Value castedSource =
1724- LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor. getSource () );
1736+ LLVM::BitcastOp::create (rewriter, loc, packedType, source );
17251737 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp4Op>(
17261738 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17271739 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF16 ()) {
17281740 // CvtPkScalePk8F16Fp8Op
17291741 // vector<8xf8E4M3FN>
1730-
1731- // vector<2xi32>
1732- VectorType v2xi32 = VectorType::get (2 , i32 );
1733- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1734-
1742+ Value castedSource =
1743+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17351744 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp8Op>(
17361745 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17371746 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF16 ()) {
17381747 // CvtPkScalePk8F16Bf8Op
17391748 // vector<8xf8E5M2>
1740-
1741- // vector<2xi32>
1742- VectorType v2xi32 = VectorType::get (2 , i32 );
1743- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1744-
1749+ Value castedSource =
1750+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17451751 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Bf8Op>(
17461752 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17471753 } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isBF16 ()) {
17481754 // CvtPkScalePk8Bf16Fp4Op
17491755 // i32
17501756 Value castedSource =
1751- LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor. getSource () );
1757+ LLVM::BitcastOp::create (rewriter, loc, packedType, source );
17521758 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp4Op>(
17531759 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17541760 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isBF16 ()) {
17551761 // CvtPkScalePk8Bf16Fp8Op
17561762 // vector<8xf8E4M3FN>
1757-
1758- // vector<2xi32>
1759- VectorType v2xi32 = VectorType::get (2 , i32 );
1760- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1761-
1763+ Value castedSource =
1764+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17621765 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp8Op>(
17631766 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17641767 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isBF16 ()) {
17651768 // CvtPkScalePk8Bf16Bf8Op
17661769 // vector<8xf8E5M2>
1767-
1768- // vector<2xi32>
1769- VectorType v2xi32 = VectorType::get (2 , i32 );
1770- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1771-
1770+ Value castedSource =
1771+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17721772 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Bf8Op>(
17731773 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17741774 } else if (isa<Float4E2M1FNType>(srcElemType) and destElemType.isF32 ()) {
17751775 // CvtPkScalePk8F32Fp4Op
17761776 // i32
17771777 Value castedSource =
1778- LLVM::BitcastOp::create (rewriter, loc, i32 , adaptor. getSource () );
1778+ LLVM::BitcastOp::create (rewriter, loc, packedType, source );
17791779 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp4Op>(
17801780 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17811781 } else if (isa<Float8E4M3FNType>(srcElemType) and destElemType.isF32 ()) {
17821782 // CvtPkScalePk8F32Fp8Op
17831783 // vector<8xf8E4M3FN>
1784-
1785- // vector<2xi32>
1786- VectorType v2xi32 = VectorType::get (2 , i32 );
1787- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1788-
1784+ Value castedSource =
1785+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17891786 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp8Op>(
17901787 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17911788 } else if (isa<Float8E5M2Type>(srcElemType) and destElemType.isF32 ()) {
17921789 // CvtPkScalePk8F32Bf8Op
17931790 // vector<8xf8E5M2>
1794-
1795- // vector<2xi32>
1796- VectorType v2xi32 = VectorType::get (2 , i32 );
1797- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v2xi32, source);
1798-
1791+ Value castedSource =
1792+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
17991793 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Bf8Op>(
18001794 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18011795 }
@@ -1808,44 +1802,38 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
18081802 // CvtPkScalePk16${largeT}${smallT}
18091803 else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF16 ()) {
18101804 // CvtPkScalePk16F16Fp6Op
1811- VectorType v3xi32 = VectorType::get (3 , i32 );
1812- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1813-
1805+ Value castedSource =
1806+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18141807 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Fp6Op>(
18151808 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18161809 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF16 ()) {
18171810 // CvtPkScalePk16F16Bf6Op
1818- VectorType v3xi32 = VectorType::get (3 , i32 );
1819- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1820-
1811+ Value castedSource =
1812+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18211813 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Bf6Op>(
18221814 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18231815 } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isBF16 ()) {
18241816 // CvtPkScalePk16Bf16Fp6Op
1825- VectorType v3xi32 = VectorType::get (3 , i32 );
1826- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1827-
1817+ Value castedSource =
1818+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18281819 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Fp6Op>(
18291820 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18301821 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isBF16 ()) {
18311822 // CvtPkScalePk16Bf16Bf6Op
1832- VectorType v3xi32 = VectorType::get (3 , i32 );
1833- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1834-
1823+ Value castedSource =
1824+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18351825 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Bf6Op>(
18361826 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18371827 } else if (isa<Float6E2M3FNType>(srcElemType) and destElemType.isF32 ()) {
18381828 // CvtPkScalePk16F32Fp6Op
1839- VectorType v3xi32 = VectorType::get (3 , i32 );
1840- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1841-
1829+ Value castedSource =
1830+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18421831 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Fp6Op>(
18431832 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18441833 } else if (isa<Float6E3M2FNType>(srcElemType) and destElemType.isF32 ()) {
18451834 // CvtPkScalePk16F32Bf6Op
1846- VectorType v3xi32 = VectorType::get (3 , i32 );
1847- Value castedSource = LLVM::BitcastOp::create (rewriter, loc, v3xi32, source);
1848-
1835+ Value castedSource =
1836+ LLVM::BitcastOp::create (rewriter, loc, packedType, source);
18491837 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Bf6Op>(
18501838 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
18511839 } else {
0 commit comments