@@ -1647,6 +1647,46 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16471647 return bits;
16481648}
16491649
1650+ static std::optional<StringRef>
1651+ scaledExtPacked816ToIntrinsic (Type srcElemType, Type destElemType) {
1652+ using fp4 = Float4E2M1FNType;
1653+ using fp8 = Float8E4M3FNType;
1654+ using bf8 = Float8E5M2Type;
1655+ using fp6 = Float6E2M3FNType;
1656+ using bf6 = Float6E3M2FNType;
1657+ if (isa<fp4>(srcElemType) && destElemType.isF16 ())
1658+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName ();
1659+ if (isa<fp8>(srcElemType) && destElemType.isF16 ())
1660+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName ();
1661+ if (isa<bf8>(srcElemType) && destElemType.isF16 ())
1662+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName ();
1663+ if (isa<fp4>(srcElemType) && destElemType.isBF16 ())
1664+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName ();
1665+ if (isa<fp8>(srcElemType) && destElemType.isBF16 ())
1666+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName ();
1667+ if (isa<bf8>(srcElemType) && destElemType.isBF16 ())
1668+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName ();
1669+ if (isa<fp4>(srcElemType) && destElemType.isF32 ())
1670+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName ();
1671+ if (isa<fp8>(srcElemType) && destElemType.isF32 ())
1672+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName ();
1673+ if (isa<bf8>(srcElemType) && destElemType.isF32 ())
1674+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName ();
1675+ if (isa<fp6>(srcElemType) && destElemType.isF16 ())
1676+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName ();
1677+ if (isa<bf6>(srcElemType) && destElemType.isF16 ())
1678+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName ();
1679+ if (isa<fp6>(srcElemType) && destElemType.isBF16 ())
1680+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName ();
1681+ if (isa<bf6>(srcElemType) && destElemType.isBF16 ())
1682+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName ();
1683+ if (isa<fp6>(srcElemType) && destElemType.isF32 ())
1684+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName ();
1685+ if (isa<bf6>(srcElemType) && destElemType.isF32 ())
1686+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName ();
1687+ return std::nullopt ;
1688+ }
1689+
16501690LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite (
16511691 ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
16521692 ConversionPatternRewriter &rewriter) const {
@@ -1694,54 +1734,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16941734 Value castedSource =
16951735 LLVM::BitcastOp::create (rewriter, loc, packedType, source);
16961736
1697- if (isa<fp4>(srcElemType) && destElemType.isF16 ()) {
1698- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp4Op>(
1699- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1700- } else if (isa<fp8>(srcElemType) && destElemType.isF16 ()) {
1701- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Fp8Op>(
1702- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1703- } else if (isa<bf8>(srcElemType) && destElemType.isF16 ()) {
1704- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F16Bf8Op>(
1705- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1706- } else if (isa<fp4>(srcElemType) && destElemType.isBF16 ()) {
1707- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp4Op>(
1708- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1709- } else if (isa<fp8>(srcElemType) && destElemType.isBF16 ()) {
1710- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Fp8Op>(
1711- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1712- } else if (isa<bf8>(srcElemType) && destElemType.isBF16 ()) {
1713- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8Bf16Bf8Op>(
1714- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1715- } else if (isa<fp4>(srcElemType) && destElemType.isF32 ()) {
1716- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp4Op>(
1717- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1718- } else if (isa<fp8>(srcElemType) && destElemType.isF32 ()) {
1719- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Fp8Op>(
1720- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1721- } else if (isa<bf8>(srcElemType) && destElemType.isF32 ()) {
1722- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk8F32Bf8Op>(
1723- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1724- } else if (isa<fp6>(srcElemType) && destElemType.isF16 ()) {
1725- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Fp6Op>(
1726- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1727- } else if (isa<bf6>(srcElemType) && destElemType.isF16 ()) {
1728- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F16Bf6Op>(
1729- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1730- } else if (isa<fp6>(srcElemType) && destElemType.isBF16 ()) {
1731- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Fp6Op>(
1732- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1733- } else if (isa<bf6>(srcElemType) && destElemType.isBF16 ()) {
1734- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16Bf16Bf6Op>(
1735- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1736- } else if (isa<fp6>(srcElemType) && destElemType.isF32 ()) {
1737- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Fp6Op>(
1738- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1739- } else if (isa<bf6>(srcElemType) && destElemType.isF32 ()) {
1740- rewriter.replaceOpWithNewOp <ROCDL::CvtPkScalePk16F32Bf6Op>(
1741- op, op.getResult ().getType (), castedSource, castedScale, scaleSel);
1742- } else {
1743- return failure ();
1744- }
1737+ std::optional<StringRef> maybeIntrinsic =
1738+ scaledExtPacked816ToIntrinsic (srcElemType, destElemType);
1739+ if (!maybeIntrinsic.has_value ())
1740+ return op.emitOpError (
1741+ " no intrinsic matching packed scaled conversion on the given chipset" );
1742+
1743+ OperationState loweredOp (loc, *maybeIntrinsic);
1744+ loweredOp.addTypes ({op.getResult ().getType ()});
1745+ loweredOp.addOperands ({castedSource, castedScale});
1746+
1747+ SmallVector<NamedAttribute, 1 > attrs;
1748+ attrs.push_back (
1749+ NamedAttribute (" scaleSel" , rewriter.getI32IntegerAttr (scaleSel)));
1750+
1751+ loweredOp.addAttributes (attrs);
1752+ Operation *lowered = rewriter.create (loweredOp);
1753+ rewriter.replaceOp (op, lowered);
17451754
17461755 return success ();
17471756}
0 commit comments