Skip to content

Commit 6119d1f

Browse files
authored
[RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL. (llvm#159205)
These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes llvm#159152.
1 parent 24b03d3 commit 6119d1f

File tree

2 files changed

+72
-37
lines changed

2 files changed

+72
-37
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17365,18 +17365,9 @@ struct NodeExtensionHelper {
1736517365
case RISCVISD::VWSUBU_W_VL:
1736617366
case RISCVISD::VFWADD_W_VL:
1736717367
case RISCVISD::VFWSUB_W_VL:
17368-
if (OperandIdx == 1) {
17369-
SupportsZExt =
17370-
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
17371-
SupportsSExt =
17372-
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
17373-
SupportsFPExt =
17374-
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
17375-
// There's no existing extension here, so we don't have to worry about
17376-
// making sure it gets removed.
17377-
EnforceOneUse = false;
17368+
// Operand 1 can't be changed.
17369+
if (OperandIdx == 1)
1737817370
break;
17379-
}
1738017371
[[fallthrough]];
1738117372
default:
1738217373
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17414,20 +17405,20 @@ struct NodeExtensionHelper {
1741417405
case RISCVISD::ADD_VL:
1741517406
case RISCVISD::MUL_VL:
1741617407
case RISCVISD::OR_VL:
17417-
case RISCVISD::VWADD_W_VL:
17418-
case RISCVISD::VWADDU_W_VL:
1741917408
case RISCVISD::FADD_VL:
1742017409
case RISCVISD::FMUL_VL:
17421-
case RISCVISD::VFWADD_W_VL:
1742217410
case RISCVISD::VFMADD_VL:
1742317411
case RISCVISD::VFNMSUB_VL:
1742417412
case RISCVISD::VFNMADD_VL:
1742517413
case RISCVISD::VFMSUB_VL:
1742617414
return true;
17415+
case RISCVISD::VWADD_W_VL:
17416+
case RISCVISD::VWADDU_W_VL:
1742717417
case ISD::SUB:
1742817418
case RISCVISD::SUB_VL:
1742917419
case RISCVISD::VWSUB_W_VL:
1743017420
case RISCVISD::VWSUBU_W_VL:
17421+
case RISCVISD::VFWADD_W_VL:
1743117422
case RISCVISD::FSUB_VL:
1743217423
case RISCVISD::VFWSUB_W_VL:
1743317424
case ISD::SHL:
@@ -17546,6 +17537,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1754617537
Subtarget);
1754717538
}
1754817539

17540+
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17541+
///
17542+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17543+
/// can be used to apply the pattern.
17544+
static std::optional<CombineResult>
17545+
canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17546+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17547+
const RISCVSubtarget &Subtarget) {
17548+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17549+
Subtarget);
17550+
}
17551+
17552+
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17553+
///
17554+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17555+
/// can be used to apply the pattern.
17556+
static std::optional<CombineResult>
17557+
canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17558+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17559+
const RISCVSubtarget &Subtarget) {
17560+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17561+
Subtarget);
17562+
}
17563+
1754917564
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1755017565
///
1755117566
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17574,52 +17589,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1757417589
return std::nullopt;
1757517590
}
1757617591

17577-
/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
17592+
/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1757817593
///
1757917594
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1758017595
/// can be used to apply the pattern.
1758117596
static std::optional<CombineResult>
1758217597
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1758317598
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1758417599
const RISCVSubtarget &Subtarget) {
17585-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17586-
Subtarget);
17600+
if (LHS.SupportsSExt)
17601+
return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17602+
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17603+
/*RHSExt=*/std::nullopt);
17604+
return std::nullopt;
1758717605
}
1758817606

17589-
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17607+
/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1759017608
///
1759117609
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1759217610
/// can be used to apply the pattern.
1759317611
static std::optional<CombineResult>
1759417612
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1759517613
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1759617614
const RISCVSubtarget &Subtarget) {
17597-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17598-
Subtarget);
17615+
if (LHS.SupportsZExt)
17616+
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17617+
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17618+
/*RHSExt=*/std::nullopt);
17619+
return std::nullopt;
1759917620
}
1760017621

17601-
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
17622+
/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1760217623
///
1760317624
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1760417625
/// can be used to apply the pattern.
1760517626
static std::optional<CombineResult>
1760617627
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1760717628
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1760817629
const RISCVSubtarget &Subtarget) {
17609-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17610-
Subtarget);
17611-
}
17612-
17613-
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17614-
///
17615-
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17616-
/// can be used to apply the pattern.
17617-
static std::optional<CombineResult>
17618-
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17619-
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17620-
const RISCVSubtarget &Subtarget) {
17621-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17622-
Subtarget);
17630+
if (LHS.SupportsFPExt)
17631+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17632+
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17633+
/*RHSExt=*/std::nullopt);
17634+
return std::nullopt;
1762317635
}
1762417636

1762517637
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17662,7 +17674,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1766217674
case RISCVISD::VFNMSUB_VL:
1766317675
Strategies.push_back(canFoldToVWWithSameExtension);
1766417676
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17665-
Strategies.push_back(canFoldToVWWithBF16EXT);
17677+
Strategies.push_back(canFoldToVWWithSameExtBF16);
1766617678
break;
1766717679
case ISD::MUL:
1766817680
case RISCVISD::MUL_VL:
@@ -17674,7 +17686,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1767417686
case ISD::SHL:
1767517687
case RISCVISD::SHL_VL:
1767617688
// shl -> vwsll
17677-
Strategies.push_back(canFoldToVWWithZEXT);
17689+
Strategies.push_back(canFoldToVWWithSameExtZEXT);
1767817690
break;
1767917691
case RISCVISD::VWADD_W_VL:
1768017692
case RISCVISD::VWSUB_W_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) {
7575
%i = or <2 x i16> %h, %g
7676
ret <2 x i16> %i
7777
}
78+
79+
; Make sure we have a vsext.vl and a vwaddu.vx.
80+
define <4 x i32> @pr159152(<4 x i8> %x) {
81+
; NO_FOLDING-LABEL: pr159152:
82+
; NO_FOLDING: # %bb.0:
83+
; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
84+
; NO_FOLDING-NEXT: vsext.vf2 v9, v8
85+
; NO_FOLDING-NEXT: li a0, 9
86+
; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0
87+
; NO_FOLDING-NEXT: ret
88+
;
89+
; FOLDING-LABEL: pr159152:
90+
; FOLDING: # %bb.0:
91+
; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
92+
; FOLDING-NEXT: vsext.vf2 v9, v8
93+
; FOLDING-NEXT: li a0, 9
94+
; FOLDING-NEXT: vwaddu.vx v8, v9, a0
95+
; FOLDING-NEXT: ret
96+
%a = sext <4 x i8> %x to <4 x i16>
97+
%b = zext <4 x i16> %a to <4 x i32>
98+
%c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9>
99+
ret <4 x i32> %c
100+
}

0 commit comments

Comments
 (0)