Skip to content

Commit 274c3f1

Browse files
committed
fixup! Use a separate strategy for this in getSupportedFoldings.
This avoids a root opcode check elsewhere.
1 parent aafd72e commit 274c3f1

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16946,7 +16946,7 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694616946
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694716947
/*RHSExt=*/{ExtKind::FPExt});
1694816948
if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16949-
RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL)
16949+
RHS.SupportsBF16Ext)
1695016950
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1695116951
Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
1695216952
/*RHSExt=*/{ExtKind::BF16Ext});
@@ -16963,10 +16963,9 @@ static std::optional<CombineResult>
1696316963
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1696416964
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1696516965
const RISCVSubtarget &Subtarget) {
16966-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
16967-
ExtKind::ZExt | ExtKind::SExt |
16968-
ExtKind::FPExt | ExtKind::BF16Ext,
16969-
DAG, Subtarget);
16966+
return canFoldToVWWithSameExtensionImpl(
16967+
Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
16968+
Subtarget);
1697016969
}
1697116970

1697216971
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -17033,6 +17032,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1703317032
Subtarget);
1703417033
}
1703517034

17035+
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17036+
///
17037+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17038+
/// can be used to apply the pattern.
17039+
static std::optional<CombineResult>
17040+
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17041+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17042+
const RISCVSubtarget &Subtarget) {
17043+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17044+
Subtarget);
17045+
}
17046+
1703617047
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
1703717048
///
1703817049
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17072,6 +17083,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1707217083
case RISCVISD::VFNMADD_VL:
1707317084
case RISCVISD::VFNMSUB_VL:
1707417085
Strategies.push_back(canFoldToVWWithSameExtension);
17086+
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17087+
Strategies.push_back(canFoldToVWWithBF16EXT);
1707517088
break;
1707617089
case ISD::MUL:
1707717090
case RISCVISD::MUL_VL:

0 commit comments

Comments
 (0)