Skip to content

Commit a156475

Browse files
committed
[RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL.
These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes #159152.
1 parent 9945af0 commit a156475

File tree

2 files changed

+51
-39
lines changed

2 files changed

+51
-39
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17325,18 +17325,9 @@ struct NodeExtensionHelper {
1732517325
case RISCVISD::VWSUBU_W_VL:
1732617326
case RISCVISD::VFWADD_W_VL:
1732717327
case RISCVISD::VFWSUB_W_VL:
17328-
if (OperandIdx == 1) {
17329-
SupportsZExt =
17330-
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
17331-
SupportsSExt =
17332-
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
17333-
SupportsFPExt =
17334-
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
17335-
// There's no existing extension here, so we don't have to worry about
17336-
// making sure it gets removed.
17337-
EnforceOneUse = false;
17328+
// Operand 1 can't be changed.
17329+
if (OperandIdx == 1)
1733817330
break;
17339-
}
1734017331
[[fallthrough]];
1734117332
default:
1734217333
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17374,20 +17365,20 @@ struct NodeExtensionHelper {
1737417365
case RISCVISD::ADD_VL:
1737517366
case RISCVISD::MUL_VL:
1737617367
case RISCVISD::OR_VL:
17377-
case RISCVISD::VWADD_W_VL:
17378-
case RISCVISD::VWADDU_W_VL:
1737917368
case RISCVISD::FADD_VL:
1738017369
case RISCVISD::FMUL_VL:
17381-
case RISCVISD::VFWADD_W_VL:
1738217370
case RISCVISD::VFMADD_VL:
1738317371
case RISCVISD::VFNMSUB_VL:
1738417372
case RISCVISD::VFNMADD_VL:
1738517373
case RISCVISD::VFMSUB_VL:
1738617374
return true;
17375+
case RISCVISD::VWADD_W_VL:
17376+
case RISCVISD::VWADDU_W_VL:
1738717377
case ISD::SUB:
1738817378
case RISCVISD::SUB_VL:
1738917379
case RISCVISD::VWSUB_W_VL:
1739017380
case RISCVISD::VWSUBU_W_VL:
17381+
case RISCVISD::VFWADD_W_VL:
1739117382
case RISCVISD::FSUB_VL:
1739217383
case RISCVISD::VFWSUB_W_VL:
1739317384
case ISD::SHL:
@@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1750617497
Subtarget);
1750717498
}
1750817499

17500+
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17501+
///
17502+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17503+
/// can be used to apply the pattern.
17504+
static std::optional<CombineResult>
17505+
canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17506+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17507+
const RISCVSubtarget &Subtarget) {
17508+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17509+
Subtarget);
17510+
}
17511+
17512+
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17513+
///
17514+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17515+
/// can be used to apply the pattern.
17516+
static std::optional<CombineResult>
17517+
canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17518+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17519+
const RISCVSubtarget &Subtarget) {
17520+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17521+
Subtarget);
17522+
}
17523+
1750917524
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1751017525
///
1751117526
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17534,52 +17549,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1753417549
return std::nullopt;
1753517550
}
1753617551

17537-
/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
17552+
/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1753817553
///
1753917554
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1754017555
/// can be used to apply the pattern.
1754117556
static std::optional<CombineResult>
1754217557
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1754317558
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1754417559
const RISCVSubtarget &Subtarget) {
17545-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17546-
Subtarget);
17560+
if (LHS.SupportsSExt)
17561+
return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17562+
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17563+
/*RHSExt=*/std::nullopt);
17564+
return std::nullopt;
1754717565
}
1754817566

17549-
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17567+
/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1755017568
///
1755117569
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1755217570
/// can be used to apply the pattern.
1755317571
static std::optional<CombineResult>
1755417572
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1755517573
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1755617574
const RISCVSubtarget &Subtarget) {
17557-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17558-
Subtarget);
17575+
if (LHS.SupportsZExt)
17576+
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17577+
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17578+
/*RHSExt=*/std::nullopt);
17579+
return std::nullopt;
1755917580
}
1756017581

17561-
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
17582+
/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1756217583
///
1756317584
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1756417585
/// can be used to apply the pattern.
1756517586
static std::optional<CombineResult>
1756617587
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1756717588
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1756817589
const RISCVSubtarget &Subtarget) {
17569-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17570-
Subtarget);
17571-
}
17572-
17573-
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17574-
///
17575-
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17576-
/// can be used to apply the pattern.
17577-
static std::optional<CombineResult>
17578-
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17579-
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17580-
const RISCVSubtarget &Subtarget) {
17581-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17582-
Subtarget);
17590+
if (LHS.SupportsFPExt)
17591+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17592+
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17593+
/*RHSExt=*/std::nullopt);
17594+
return std::nullopt;
1758317595
}
1758417596

1758517597
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1762217634
case RISCVISD::VFNMSUB_VL:
1762317635
Strategies.push_back(canFoldToVWWithSameExtension);
1762417636
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17625-
Strategies.push_back(canFoldToVWWithBF16EXT);
17637+
Strategies.push_back(canFoldToVWWithSameExtBF16);
1762617638
break;
1762717639
case ISD::MUL:
1762817640
case RISCVISD::MUL_VL:
@@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1763417646
case ISD::SHL:
1763517647
case RISCVISD::SHL_VL:
1763617648
// shl -> vwsll
17637-
Strategies.push_back(canFoldToVWWithZEXT);
17649+
Strategies.push_back(canFoldToVWWithSameExtZEXT);
1763817650
break;
1763917651
case RISCVISD::VWADD_W_VL:
1764017652
case RISCVISD::VWSUB_W_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ define <4 x i32> @pr159152(<4 x i8> %x) {
6464
; NO_FOLDING-LABEL: pr159152:
6565
; NO_FOLDING: # %bb.0:
6666
; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
67-
; NO_FOLDING-NEXT: vzext.vf2 v9, v8
67+
; NO_FOLDING-NEXT: vsext.vf2 v9, v8
6868
; NO_FOLDING-NEXT: li a0, 9
6969
; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0
7070
; NO_FOLDING-NEXT: ret
7171
;
7272
; FOLDING-LABEL: pr159152:
7373
; FOLDING: # %bb.0:
7474
; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
75-
; FOLDING-NEXT: vzext.vf2 v9, v8
75+
; FOLDING-NEXT: vsext.vf2 v9, v8
7676
; FOLDING-NEXT: li a0, 9
7777
; FOLDING-NEXT: vwaddu.vx v8, v9, a0
7878
; FOLDING-NEXT: ret

0 commit comments

Comments
 (0)