@@ -16309,7 +16309,12 @@ namespace {
1630916309// apply a combine.
1631016310struct CombineResult;
1631116311
16312- enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
16312+ enum ExtKind : uint8_t { // each enumerator is a distinct single bit, so kinds can be OR-ed into a mask and tested with &
16313+ ZExt = 1 << 0,
16314+ SExt = 1 << 1,
16315+ FPExt = 1 << 2,
16316+ BF16Ext = 1 << 3 // own bit, separate from FPExt, so bf16 sources are tracked independently of other FP extensions
16317+ };
1631316318/// Helper class for folding sign/zero extensions.
1631416319/// In particular, this class is used for the following combines:
1631516320/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
1634416349 /// instance, a splat constant (e.g., 3), would support being both sign and
1634516350 /// zero extended.
1634616351 bool SupportsSExt;
16347- /// Records if this operand is like being floating-Point extended.
16352+ /// Records if this operand is like being floating point extended.
1634816353 bool SupportsFPExt;
16354+ /// Records if this operand is extended from bf16.
16355+ bool SupportsBF16Ext;
1634916356 /// This boolean captures whether we care if this operand would still be
1635016357 /// around after the folding happens.
1635116358 bool EnforceOneUse;
@@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
1638116388 case ExtKind::ZExt:
1638216389 return RISCVISD::VZEXT_VL;
1638316390 case ExtKind::FPExt:
16391+ case ExtKind::BF16Ext:
1638416392 return RISCVISD::FP_EXTEND_VL;
1638516393 }
1638616394 llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
1640216410 if (Source.getValueType() == NarrowVT)
1640316411 return Source;
1640416412
16405- // vfmadd_vl -> vfwmadd_vl can take bf16 operands
16406- if (Source.getValueType().getVectorElementType() == MVT::bf16) {
16407- assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
16408- Root->getOpcode() == RISCVISD::VFMADD_VL);
16409- return Source;
16410- }
16411-
1641216413 unsigned ExtOpc = getExtOpc(*SupportsExt);
1641316414
1641416415 // If we need an extension, we should be changing the type.
@@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
1645116452 // Determine the narrow size.
1645216453 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1645316454
16454- MVT EltVT = SupportsExt == ExtKind::FPExt
16455+ MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
16456+ : SupportsExt == ExtKind::FPExt
1645516457 ? MVT::getFloatingPointVT(NarrowSize)
1645616458 : MVT::getIntegerVT(NarrowSize);
1645716459
@@ -16628,17 +16630,17 @@ struct NodeExtensionHelper {
1662816630 EnforceOneUse = false;
1662916631 }
1663016632
16631- bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
16632- const RISCVSubtarget &Subtarget) {
16633+ bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) { // new form: no Root parameter; returns true only for the cases listed below
16634+ if (NarrowEltVT == MVT::f32)
16635+ return true; // f32 is accepted without querying the subtarget
1663316636 // Any f16 extension will need zvfh
16634- if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
16635- return false;
16636- // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
16637- // zvfbfwma
16638- if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
16639- Root->getOpcode() != RISCVISD::VFMADD_VL))
16640- return false;
16641- return true;
16637+ if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())
16638+ return true;
16639+ return false; // every other narrow element type, bf16 included, is rejected by this helper
16640+ }
16641+
16642+ bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16643+ return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma(); // true only when the narrow element type is bf16 and the subtarget reports Zvfbfwma
1664216644 }
1664316645
1664416646 /// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16650,7 @@ struct NodeExtensionHelper {
1664816650 SupportsZExt = false;
1664916651 SupportsSExt = false;
1665016652 SupportsFPExt = false;
16653+ SupportsBF16Ext = false;
1665116654 EnforceOneUse = true;
1665216655 unsigned Opc = OrigOperand.getOpcode();
1665316656 // For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16682,11 @@ struct NodeExtensionHelper {
1667916682 case RISCVISD::FP_EXTEND_VL: {
1668016683 MVT NarrowEltVT =
1668116684 OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
16682- if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
16683- break;
16684- SupportsFPExt = true;
16685+ if (isSupportedFPExtend(NarrowEltVT, Subtarget))
16686+ SupportsFPExt = true;
16687+ if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
16688+ SupportsBF16Ext = true;
16689+
1668516690 break;
1668616691 }
1668716692 case ISD::SPLAT_VECTOR:
@@ -16698,16 +16703,16 @@ struct NodeExtensionHelper {
1669816703 if (Op.getOpcode() != ISD::FP_EXTEND)
1669916704 break;
1670016705
16701- if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
16702- Subtarget))
16703- break;
16704-
1670516706 unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1670616707 unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
1670716708 if (NarrowSize != ScalarBits)
1670816709 break;
1670916710
16710- SupportsFPExt = true;
16711+ if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
16712+ SupportsFPExt = true;
16713+ if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
16714+ Subtarget))
16715+ SupportsBF16Ext = true;
1671116716 break;
1671216717 }
1671316718 default:
@@ -16940,6 +16945,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694016945 return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1694116946 Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694216947 /*RHSExt=*/{ExtKind::FPExt});
16948+ if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16949+ RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL)
16950+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16951+ Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
16952+ /*RHSExt=*/{ExtKind::BF16Ext});
1694316953 return std::nullopt;
1694416954}
1694516955
@@ -16953,9 +16963,10 @@ static std::optional<CombineResult>
1695316963canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1695416964 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1695516965 const RISCVSubtarget &Subtarget) {
16956- return canFoldToVWWithSameExtensionImpl(
16957- Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
16958- Subtarget);
16966+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
16967+ ExtKind::ZExt | ExtKind::SExt |
16968+ ExtKind::FPExt | ExtKind::BF16Ext,
16969+ DAG, Subtarget);
1695916970}
1696016971
1696116972/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
0 commit comments