Skip to content

Commit 7c44f09

Browse files
committed
Add top-level if condition for each src type
1 parent a3db728 commit 7c44f09

File tree

1 file changed

+47
-31
lines changed

1 file changed

+47
-31
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,37 +1667,53 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
16671667
using bf8 = Float8E5M2Type;
16681668
using fp6 = Float6E2M3FNType;
16691669
using bf6 = Float6E3M2FNType;
1670-
if (isa<fp4>(srcElemType) && destElemType.isF16())
1671-
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1672-
if (isa<fp8>(srcElemType) && destElemType.isF16())
1673-
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1674-
if (isa<bf8>(srcElemType) && destElemType.isF16())
1675-
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1676-
if (isa<fp4>(srcElemType) && destElemType.isBF16())
1677-
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1678-
if (isa<fp8>(srcElemType) && destElemType.isBF16())
1679-
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1680-
if (isa<bf8>(srcElemType) && destElemType.isBF16())
1681-
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1682-
if (isa<fp4>(srcElemType) && destElemType.isF32())
1683-
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1684-
if (isa<fp8>(srcElemType) && destElemType.isF32())
1685-
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1686-
if (isa<bf8>(srcElemType) && destElemType.isF32())
1687-
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1688-
if (isa<fp6>(srcElemType) && destElemType.isF16())
1689-
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1690-
if (isa<bf6>(srcElemType) && destElemType.isF16())
1691-
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1692-
if (isa<fp6>(srcElemType) && destElemType.isBF16())
1693-
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1694-
if (isa<bf6>(srcElemType) && destElemType.isBF16())
1695-
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1696-
if (isa<fp6>(srcElemType) && destElemType.isF32())
1697-
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1698-
if (isa<bf6>(srcElemType) && destElemType.isF32())
1699-
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1700-
return std::nullopt;
1670+
if (isa<fp4>(srcElemType)) {
1671+
if (destElemType.isF16())
1672+
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
1673+
if (destElemType.isBF16())
1674+
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
1675+
if (destElemType.isF32())
1676+
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
1677+
return std::nullopt;
1678+
}
1679+
if (isa<fp8>(srcElemType)) {
1680+
if (destElemType.isF16())
1681+
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
1682+
if (destElemType.isBF16())
1683+
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
1684+
if (destElemType.isF32())
1685+
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
1686+
return std::nullopt;
1687+
}
1688+
if (isa<bf8>(srcElemType)) {
1689+
if (destElemType.isF16())
1690+
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
1691+
if (destElemType.isBF16())
1692+
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
1693+
if (destElemType.isF32())
1694+
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
1695+
return std::nullopt;
1696+
}
1697+
if (isa<fp6>(srcElemType)) {
1698+
if (destElemType.isF16())
1699+
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
1700+
if (destElemType.isBF16())
1701+
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
1702+
if (destElemType.isF32())
1703+
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
1704+
return std::nullopt;
1705+
}
1706+
if (isa<bf6>(srcElemType)) {
1707+
if (destElemType.isF16())
1708+
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
1709+
if (destElemType.isBF16())
1710+
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
1711+
if (destElemType.isF32())
1712+
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
1713+
return std::nullopt;
1714+
}
1715+
llvm_unreachable("invalid combination of element types for packed conversion "
1716+
"instructions");
17011717
}
17021718

17031719
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(

0 commit comments

Comments
 (0)