@@ -1620,7 +1620,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16201620 // firstScaleByte are merged into a single attribute scaleSel. This is how
16211621 // those values are merged together.
16221622 assert (llvm::is_contained ({16 , 32 }, blockSize));
1623- assert (llvm::is_contained ({4 , 6 , 8 }, bitWidth));
1623+ assert (llvm::is_contained (::llvm::ArrayRef< unsigned > {4 , 6 , 8 }, bitWidth));
16241624
16251625 const bool is_fp8 = bitWidth == 8 ;
16261626 const bool is_block_16 = blockSize == 16 ;
@@ -1653,6 +1653,11 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16531653LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite (
16541654 ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
16551655 ConversionPatternRewriter &rewriter) const {
1656+ using fp4 = Float4E2M1FNType;
1657+ using fp8 = Float8E4M3FNType;
1658+ using bf8 = Float8E5M2Type;
1659+ using fp6 = Float6E2M3FNType;
1660+ using bf6 = Float6E3M2FNType;
16561661 int32_t firstScaleLane = op.getFirstScaleLane ();
16571662 int32_t firstScaleByte = op.getFirstScaleByte ();
16581663 int32_t blockSize = op.getBlockSize ();
@@ -1671,79 +1676,64 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16711676
16721677 Value source = adaptor.getSource ();
16731678 Type packedType;
1674- if (isa<Float4E2M1FNType >(srcElemType)) {
1679+ if (isa<fp4 >(srcElemType)) {
16751680 packedType = i32 ;
16761681 packedType = getTypeConverter ()->convertType (packedType);
1677- } else if (isa<Float8E4M3FNType>(srcElemType) ||
1678- isa<Float8E5M2Type>(srcElemType)) {
1682+ } else if (isa<fp8, bf8>(srcElemType)) {
16791683 packedType = VectorType::get (2 , i32 );
16801684 packedType = getTypeConverter ()->convertType (packedType);
1681- } else if (isa<Float6E2M3FNType>(srcElemType) ||
1682- isa<Float6E3M2FNType>(srcElemType)) {
1685+ } else if (isa<fp6, bf6>(srcElemType)) {
16831686 packedType = VectorType::get (3 , i32 );
16841687 packedType = getTypeConverter ()->convertType (packedType);
16851688 } else {
16861689 llvm_unreachable (" invalid element type for scaled ext" );
16871690 }
1688- // smallT = [Fp4, Fp8, Bf8]
1689- // Bf8 = E5M2
1690- // Fp8 = E4M3
1691- //
1692- // largeT = [F16, Bf16, F32]
1693- // CvtPkScalePk8${largeT}${smallT}
16941691 Value castedSource =
16951692 LLVM::BitcastOp::create (rewriter, loc, packedType, source);
16961693
1697- if (isa<Float4E2M1FNType >(srcElemType) && destElemType.isF16 ()) {
1694+ if (isa<fp4 >(srcElemType) && destElemType.isF16 ()) {
16981695 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp4Op>(
16991696 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1700- } else if (isa<Float8E4M3FNType >(srcElemType) && destElemType.isF16 ()) {
1697+ } else if (isa<fp8 >(srcElemType) && destElemType.isF16 ()) {
17011698 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp8Op>(
17021699 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1703- } else if (isa<Float8E5M2Type >(srcElemType) && destElemType.isF16 ()) {
1700+ } else if (isa<bf8 >(srcElemType) && destElemType.isF16 ()) {
17041701 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Bf8Op>(
17051702 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1706- } else if (isa<Float4E2M1FNType >(srcElemType) && destElemType.isBF16 ()) {
1703+ } else if (isa<fp4 >(srcElemType) && destElemType.isBF16 ()) {
17071704 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp4Op>(
17081705 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1709- } else if (isa<Float8E4M3FNType >(srcElemType) && destElemType.isBF16 ()) {
1706+ } else if (isa<fp8 >(srcElemType) && destElemType.isBF16 ()) {
17101707 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp8Op>(
17111708 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1712- } else if (isa<Float8E5M2Type >(srcElemType) && destElemType.isBF16 ()) {
1709+ } else if (isa<bf8 >(srcElemType) && destElemType.isBF16 ()) {
17131710 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Bf8Op>(
17141711 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1715- } else if (isa<Float4E2M1FNType >(srcElemType) && destElemType.isF32 ()) {
1712+ } else if (isa<fp4 >(srcElemType) && destElemType.isF32 ()) {
17161713 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp4Op>(
17171714 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1718- } else if (isa<Float8E4M3FNType >(srcElemType) && destElemType.isF32 ()) {
1715+ } else if (isa<fp8 >(srcElemType) && destElemType.isF32 ()) {
17191716 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp8Op>(
17201717 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1721- } else if (isa<Float8E5M2Type >(srcElemType) && destElemType.isF32 ()) {
1718+ } else if (isa<bf8 >(srcElemType) && destElemType.isF32 ()) {
17221719 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Bf8Op>(
17231720 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1724- }
1725- // smallT = [Fp6, Bf6]
1726- // Fp6 = Float6E2M3FN
1727- // Bf6 = Float6E3M2FN
1728- // largeT = [F16, Bf16, F32]
1729- //
1730- // CvtPkScalePk16${largeT}${smallT}
1731- else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16 ()) {
1721+ } else if (isa<fp6>(srcElemType) && destElemType.isF16 ()) {
17321722 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Fp6Op>(
17331723 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1734- } else if (isa<Float6E3M2FNType >(srcElemType) && destElemType.isF16 ()) {
1724+ } else if (isa<bf6 >(srcElemType) && destElemType.isF16 ()) {
17351725 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Bf6Op>(
17361726 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1737- } else if (isa<Float6E2M3FNType >(srcElemType) && destElemType.isBF16 ()) {
1727+ } else if (isa<fp6 >(srcElemType) && destElemType.isBF16 ()) {
17381728 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Fp6Op>(
17391729 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1740- } else if (isa<Float6E3M2FNType >(srcElemType) && destElemType.isBF16 ()) {
1730+ } else if (isa<bf6 >(srcElemType) && destElemType.isBF16 ()) {
17411731 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Bf6Op>(
17421732 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1743- } else if (isa<Float6E2M3FNType >(srcElemType) && destElemType.isF32 ()) {
1733+ } else if (isa<fp6 >(srcElemType) && destElemType.isF32 ()) {
17441734 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Fp6Op>(
17451735 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1746- } else if (isa<Float6E3M2FNType >(srcElemType) && destElemType.isF32 ()) {
1736+ } else if (isa<bf6 >(srcElemType) && destElemType.isF32 ()) {
17471737 rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Bf6Op>(
17481738 op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
17491739 } else {
0 commit comments