Skip to content

Commit c2f0ce3

Browse files
committed
[RISCV][WIP] Treat bf16->f32 as separate ExtKind in combineOp_VLToVWOp_VL.
This allows us to better track the narrow type we need and to fix miscompiles if f16->f32 and bf16->f32 extends are mixed. Fixes #144651. Still need to add tests, but it's late and I need sleep.
1 parent ad9e591 commit c2f0ce3

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16309,7 +16309,12 @@ namespace {
1630916309
// apply a combine.
1631016310
struct CombineResult;
1631116311

16312-
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
16312+
enum ExtKind : uint8_t {
16313+
ZExt = 1 << 0,
16314+
SExt = 1 << 1,
16315+
FPExt = 1 << 2,
16316+
BF16Ext = 1 << 3
16317+
};
1631316318
/// Helper class for folding sign/zero extensions.
1631416319
/// In particular, this class is used for the following combines:
1631516320
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
1634416349
/// instance, a splat constant (e.g., 3), would support being both sign and
1634516350
/// zero extended.
1634616351
bool SupportsSExt;
16347-
/// Records if this operand is like being floating-Point extended.
16352+
/// Records if this operand is like being floating point extended.
1634816353
bool SupportsFPExt;
16354+
/// Records if this operand is extended from bf16.
16355+
bool SupportsBF16Ext;
1634916356
/// This boolean captures whether we care if this operand would still be
1635016357
/// around after the folding happens.
1635116358
bool EnforceOneUse;
@@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
1638116388
case ExtKind::ZExt:
1638216389
return RISCVISD::VZEXT_VL;
1638316390
case ExtKind::FPExt:
16391+
case ExtKind::BF16Ext:
1638416392
return RISCVISD::FP_EXTEND_VL;
1638516393
}
1638616394
llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
1640216410
if (Source.getValueType() == NarrowVT)
1640316411
return Source;
1640416412

16405-
// vfmadd_vl -> vfwmadd_vl can take bf16 operands
16406-
if (Source.getValueType().getVectorElementType() == MVT::bf16) {
16407-
assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
16408-
Root->getOpcode() == RISCVISD::VFMADD_VL);
16409-
return Source;
16410-
}
16411-
1641216413
unsigned ExtOpc = getExtOpc(*SupportsExt);
1641316414

1641416415
// If we need an extension, we should be changing the type.
@@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
1645116452
// Determine the narrow size.
1645216453
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1645316454

16454-
MVT EltVT = SupportsExt == ExtKind::FPExt
16455+
MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
16456+
: SupportsExt == ExtKind::FPExt
1645516457
? MVT::getFloatingPointVT(NarrowSize)
1645616458
: MVT::getIntegerVT(NarrowSize);
1645716459

@@ -16628,17 +16630,17 @@ struct NodeExtensionHelper {
1662816630
EnforceOneUse = false;
1662916631
}
1663016632

16631-
bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
16632-
const RISCVSubtarget &Subtarget) {
16633+
bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16634+
if (NarrowEltVT == MVT::f32)
16635+
return true;
1663316636
// Any f16 extension will need zvfh
16634-
if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
16635-
return false;
16636-
// The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
16637-
// zvfbfwma
16638-
if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
16639-
Root->getOpcode() != RISCVISD::VFMADD_VL))
16640-
return false;
16641-
return true;
16637+
if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())
16638+
return true;
16639+
return false;
16640+
}
16641+
16642+
bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16643+
return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
1664216644
}
1664316645

1664416646
/// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16650,7 @@ struct NodeExtensionHelper {
1664816650
SupportsZExt = false;
1664916651
SupportsSExt = false;
1665016652
SupportsFPExt = false;
16653+
SupportsBF16Ext = false;
1665116654
EnforceOneUse = true;
1665216655
unsigned Opc = OrigOperand.getOpcode();
1665316656
// For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16682,11 @@ struct NodeExtensionHelper {
1667916682
case RISCVISD::FP_EXTEND_VL: {
1668016683
MVT NarrowEltVT =
1668116684
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
16682-
if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
16683-
break;
16684-
SupportsFPExt = true;
16685+
if (isSupportedFPExtend(NarrowEltVT, Subtarget))
16686+
SupportsFPExt = true;
16687+
if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
16688+
SupportsBF16Ext = true;
16689+
1668516690
break;
1668616691
}
1668716692
case ISD::SPLAT_VECTOR:
@@ -16698,16 +16703,16 @@ struct NodeExtensionHelper {
1669816703
if (Op.getOpcode() != ISD::FP_EXTEND)
1669916704
break;
1670016705

16701-
if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
16702-
Subtarget))
16703-
break;
16704-
1670516706
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1670616707
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
1670716708
if (NarrowSize != ScalarBits)
1670816709
break;
1670916710

16710-
SupportsFPExt = true;
16711+
if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
16712+
SupportsFPExt = true;
16713+
if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
16714+
Subtarget))
16715+
SupportsBF16Ext = true;
1671116716
break;
1671216717
}
1671316718
default:
@@ -16940,6 +16945,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694016945
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1694116946
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694216947
/*RHSExt=*/{ExtKind::FPExt});
16948+
if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16949+
RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL)
16950+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16951+
Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
16952+
/*RHSExt=*/{ExtKind::BF16Ext});
1694316953
return std::nullopt;
1694416954
}
1694516955

@@ -16953,9 +16963,10 @@ static std::optional<CombineResult>
1695316963
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1695416964
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1695516965
const RISCVSubtarget &Subtarget) {
16956-
return canFoldToVWWithSameExtensionImpl(
16957-
Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
16958-
Subtarget);
16966+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
16967+
ExtKind::ZExt | ExtKind::SExt |
16968+
ExtKind::FPExt | ExtKind::BF16Ext,
16969+
DAG, Subtarget);
1695916970
}
1696016971

1696116972
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))

0 commit comments

Comments
 (0)