@@ -16316,7 +16316,12 @@ namespace {
1631616316// apply a combine.
1631716317struct CombineResult;
1631816318
16319- enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Kinds of extension an operand may be folded through. Each enumerator
/// occupies its own bit of the uint8_t so that kinds can be tested against a
/// mask (see the `AllowExtMask & ExtKind::...` checks in
/// canFoldToVWWithSameExtensionImpl).
16319+ enum ExtKind : uint8_t {
16320+ ZExt = 1 << 0, // Zero extension (getExtOpc maps it to VZEXT_VL).
16321+ SExt = 1 << 1, // Sign extension.
16322+ FPExt = 1 << 2, // FP extension; narrow type is getFloatingPointVT(NarrowSize).
16323+ BF16Ext = 1 << 3 // FP extension whose narrow element type is bf16.
16324+ };
1632016325/// Helper class for folding sign/zero extensions.
1632116326/// In particular, this class is used for the following combines:
1632216327/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16351,8 +16356,10 @@ struct NodeExtensionHelper {
1635116356 /// instance, a splat constant (e.g., 3), would support being both sign and
1635216357 /// zero extended.
1635316358 bool SupportsSExt;
16354- /// Records if this operand is like being floating-Point extended.
16359+ /// Records if this operand can be treated as being floating-point extended.
1635516360 bool SupportsFPExt;
16361+ /// Records if this operand is extended from bf16.
16362+ bool SupportsBF16Ext;
1635616363 /// This boolean captures whether we care if this operand would still be
1635716364 /// around after the folding happens.
1635816365 bool EnforceOneUse;
@@ -16388,6 +16395,7 @@ struct NodeExtensionHelper {
1638816395 case ExtKind::ZExt:
1638916396 return RISCVISD::VZEXT_VL;
1639016397 case ExtKind::FPExt:
16398+ case ExtKind::BF16Ext:
1639116399 return RISCVISD::FP_EXTEND_VL;
1639216400 }
1639316401 llvm_unreachable("Unknown ExtKind enum");
@@ -16409,13 +16417,6 @@ struct NodeExtensionHelper {
1640916417 if (Source.getValueType() == NarrowVT)
1641016418 return Source;
1641116419
16412- // vfmadd_vl -> vfwmadd_vl can take bf16 operands
16413- if (Source.getValueType().getVectorElementType() == MVT::bf16) {
16414- assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
16415- Root->getOpcode() == RISCVISD::VFMADD_VL);
16416- return Source;
16417- }
16418-
1641916420 unsigned ExtOpc = getExtOpc(*SupportsExt);
1642016421
1642116422 // If we need an extension, we should be changing the type.
@@ -16458,7 +16459,8 @@ struct NodeExtensionHelper {
1645816459 // Determine the narrow size.
1645916460 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1646016461
16461- MVT EltVT = SupportsExt == ExtKind::FPExt
16462+ MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
16463+ : SupportsExt == ExtKind::FPExt
1646216464 ? MVT::getFloatingPointVT(NarrowSize)
1646316465 : MVT::getIntegerVT(NarrowSize);
1646416466
@@ -16635,17 +16637,13 @@ struct NodeExtensionHelper {
1663516637 EnforceOneUse = false;
1663616638 }
1663716639
16638- bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
16639- const RISCVSubtarget &Subtarget) {
16640- // Any f16 extension will need zvfh
16641- if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
16642- return false;
16643- // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
16644- // zvfbfwma
16645- if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
16646- Root->getOpcode() != RISCVISD::VFMADD_VL))
16647- return false;
16648- return true;
/// Returns true if an FP extension whose source element type is
/// \p NarrowEltVT is supported as a regular (ExtKind::FPExt) extension:
/// the narrow element type must be f32, or f16 when \p Subtarget reports
/// hasVInstructionsF16(). Any other element type, including bf16, yields false.
16640+ bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16641+ return (NarrowEltVT == MVT::f32 ||
16642+ (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()));
16643+ }
16644+
/// Returns true if an FP extension whose source element type is
/// \p NarrowEltVT is supported as an ExtKind::BF16Ext extension: the narrow
/// element type must be bf16 and \p Subtarget must have Zvfbfwma.
/// This predicate does not look at the root node; the restriction to
/// VFMADD_VL roots is applied where the BF16 folding strategy is registered
/// in getSupportedFoldings.
16645+ bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16646+ return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
1664916647 }
1665016648
1665116649 /// Helper method to set the various fields of this struct based on the
@@ -16655,6 +16653,7 @@ struct NodeExtensionHelper {
1665516653 SupportsZExt = false;
1665616654 SupportsSExt = false;
1665716655 SupportsFPExt = false;
16656+ SupportsBF16Ext = false;
1665816657 EnforceOneUse = true;
1665916658 unsigned Opc = OrigOperand.getOpcode();
1666016659 // For the nodes we handle below, we end up using their inputs directly: see
@@ -16686,9 +16685,11 @@ struct NodeExtensionHelper {
1668616685 case RISCVISD::FP_EXTEND_VL: {
1668716686 MVT NarrowEltVT =
1668816687 OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
16689- if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
16690- break;
16691- SupportsFPExt = true;
16688+ if (isSupportedFPExtend(NarrowEltVT, Subtarget))
16689+ SupportsFPExt = true;
16690+ if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
16691+ SupportsBF16Ext = true;
16692+
1669216693 break;
1669316694 }
1669416695 case ISD::SPLAT_VECTOR:
@@ -16705,16 +16706,16 @@ struct NodeExtensionHelper {
1670516706 if (Op.getOpcode() != ISD::FP_EXTEND)
1670616707 break;
1670716708
16708- if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
16709- Subtarget))
16710- break;
16711-
1671216709 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1671316710 unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
1671416711 if (NarrowSize != ScalarBits)
1671516712 break;
1671616713
16717- SupportsFPExt = true;
16714+ if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
16715+ SupportsFPExt = true;
16716+ if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
16717+ Subtarget))
16718+ SupportsBF16Ext = true;
1671816719 break;
1671916720 }
1672016721 default:
@@ -16947,6 +16948,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694716948 return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1694816949 Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694916950 /*RHSExt=*/{ExtKind::FPExt});
16951+ if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16952+ RHS.SupportsBF16Ext)
16953+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16954+ Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
16955+ /*RHSExt=*/{ExtKind::BF16Ext});
1695016956 return std::nullopt;
1695116957}
1695216958
@@ -17029,6 +17035,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1702917035 Subtarget);
1703017036}
1703117037
17038+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17039+ ///
17040+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17041+ /// can be used to apply the pattern.
17042+ static std::optional<CombineResult>
17043+ canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17044+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17045+ const RISCVSubtarget &Subtarget) {
// Delegate to the shared same-extension matcher, passing only
// ExtKind::BF16Ext. NOTE(review): this argument is presumably the
// AllowExtMask tested in canFoldToVWWithSameExtensionImpl, so only the
// bf16 case can match here -- confirm against that function's signature.
17046+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17047+ Subtarget);
17048+ }
17049+
1703217050/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
1703317051///
1703417052/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17068,6 +17086,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1706817086 case RISCVISD::VFNMADD_VL:
1706917087 case RISCVISD::VFNMSUB_VL:
1707017088 Strategies.push_back(canFoldToVWWithSameExtension);
17089+ if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17090+ Strategies.push_back(canFoldToVWWithBF16EXT);
1707117091 break;
1707217092 case ISD::MUL:
1707317093 case RISCVISD::MUL_VL:
0 commit comments