Skip to content

Commit b88f7f6

Browse files
committed
Use operation name
1 parent 34ed3e9 commit b88f7f6

File tree

1 file changed

+57
-48
lines changed

1 file changed

+57
-48
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 57 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1647,6 +1647,46 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16471647
return bits;
16481648
}
16491649

1650+
static std::optional<StringRef>
1651+
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
1652+
using fp4 = Float4E2M1FNType;
1653+
using fp8 = Float8E4M3FNType;
1654+
using bf8 = Float8E5M2Type;
1655+
using fp6 = Float6E2M3FNType;
1656+
using bf6 = Float6E3M2FNType;
1657+
if (isa<fp4>(srcElemType) && destElemType.isF16())
1658+
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1659+
if (isa<fp8>(srcElemType) && destElemType.isF16())
1660+
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1661+
if (isa<bf8>(srcElemType) && destElemType.isF16())
1662+
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1663+
if (isa<fp4>(srcElemType) && destElemType.isBF16())
1664+
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1665+
if (isa<fp8>(srcElemType) && destElemType.isBF16())
1666+
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1667+
if (isa<bf8>(srcElemType) && destElemType.isBF16())
1668+
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1669+
if (isa<fp4>(srcElemType) && destElemType.isF32())
1670+
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1671+
if (isa<fp8>(srcElemType) && destElemType.isF32())
1672+
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1673+
if (isa<bf8>(srcElemType) && destElemType.isF32())
1674+
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1675+
if (isa<fp6>(srcElemType) && destElemType.isF16())
1676+
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1677+
if (isa<bf6>(srcElemType) && destElemType.isF16())
1678+
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1679+
if (isa<fp6>(srcElemType) && destElemType.isBF16())
1680+
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1681+
if (isa<bf6>(srcElemType) && destElemType.isBF16())
1682+
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1683+
if (isa<fp6>(srcElemType) && destElemType.isF32())
1684+
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1685+
if (isa<bf6>(srcElemType) && destElemType.isF32())
1686+
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1687+
return std::nullopt;
1688+
}
1689+
16501690
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16511691
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
16521692
ConversionPatternRewriter &rewriter) const {
@@ -1694,54 +1734,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
16941734
Value castedSource =
16951735
LLVM::BitcastOp::create(rewriter, loc, packedType, source);
16961736

1697-
if (isa<fp4>(srcElemType) && destElemType.isF16()) {
1698-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
1699-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1700-
} else if (isa<fp8>(srcElemType) && destElemType.isF16()) {
1701-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
1702-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1703-
} else if (isa<bf8>(srcElemType) && destElemType.isF16()) {
1704-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
1705-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1706-
} else if (isa<fp4>(srcElemType) && destElemType.isBF16()) {
1707-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
1708-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1709-
} else if (isa<fp8>(srcElemType) && destElemType.isBF16()) {
1710-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
1711-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1712-
} else if (isa<bf8>(srcElemType) && destElemType.isBF16()) {
1713-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
1714-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1715-
} else if (isa<fp4>(srcElemType) && destElemType.isF32()) {
1716-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
1717-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1718-
} else if (isa<fp8>(srcElemType) && destElemType.isF32()) {
1719-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
1720-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1721-
} else if (isa<bf8>(srcElemType) && destElemType.isF32()) {
1722-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
1723-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1724-
} else if (isa<fp6>(srcElemType) && destElemType.isF16()) {
1725-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
1726-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1727-
} else if (isa<bf6>(srcElemType) && destElemType.isF16()) {
1728-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
1729-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1730-
} else if (isa<fp6>(srcElemType) && destElemType.isBF16()) {
1731-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
1732-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1733-
} else if (isa<bf6>(srcElemType) && destElemType.isBF16()) {
1734-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
1735-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1736-
} else if (isa<fp6>(srcElemType) && destElemType.isF32()) {
1737-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
1738-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1739-
} else if (isa<bf6>(srcElemType) && destElemType.isF32()) {
1740-
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
1741-
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
1742-
} else {
1743-
return failure();
1744-
}
1737+
std::optional<StringRef> maybeIntrinsic =
1738+
scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
1739+
if (!maybeIntrinsic.has_value())
1740+
return op.emitOpError(
1741+
"no intrinsic matching packed scaled conversion on the given chipset");
1742+
1743+
OperationState loweredOp(loc, *maybeIntrinsic);
1744+
loweredOp.addTypes({op.getResult().getType()});
1745+
loweredOp.addOperands({castedSource, castedScale});
1746+
1747+
SmallVector<NamedAttribute, 1> attrs;
1748+
attrs.push_back(
1749+
NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
1750+
1751+
loweredOp.addAttributes(attrs);
1752+
Operation *lowered = rewriter.create(loweredOp);
1753+
rewriter.replaceOp(op, lowered);
17451754

17461755
return success();
17471756
}

0 commit comments

Comments
 (0)