diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 66ebda7aa586b..91e9975ba9256 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -17325,18 +17325,9 @@ struct NodeExtensionHelper { case RISCVISD::VWSUBU_W_VL: case RISCVISD::VFWADD_W_VL: case RISCVISD::VFWSUB_W_VL: - if (OperandIdx == 1) { - SupportsZExt = - Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL; - SupportsSExt = - Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL; - SupportsFPExt = - Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL; - // There's no existing extension here, so we don't have to worry about - // making sure it gets removed. - EnforceOneUse = false; + // Operand 1 can't be changed. + if (OperandIdx == 1) break; - } [[fallthrough]]; default: fillUpExtensionSupport(Root, DAG, Subtarget); @@ -17374,20 +17365,20 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::OR_VL: - case RISCVISD::VWADD_W_VL: - case RISCVISD::VWADDU_W_VL: case RISCVISD::FADD_VL: case RISCVISD::FMUL_VL: - case RISCVISD::VFWADD_W_VL: case RISCVISD::VFMADD_VL: case RISCVISD::VFNMSUB_VL: case RISCVISD::VFNMADD_VL: case RISCVISD::VFMSUB_VL: return true; + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: case ISD::SUB: case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: + case RISCVISD::VFWADD_W_VL: case RISCVISD::FSUB_VL: case RISCVISD::VFWSUB_W_VL: case ISD::SHL: @@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, Subtarget); } +/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. 
+static std::optional +canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG, + Subtarget); +} + +/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. +static std::optional +canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, + Subtarget); +} + /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that @@ -17534,7 +17549,7 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, return std::nullopt; } -/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS)) +/// Check if \p Root follows a pattern Root(sext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. 
@@ -17542,11 +17557,14 @@ static std::optional canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG, - Subtarget); + if (LHS.SupportsSExt) + return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } -/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) +/// Check if \p Root follows a pattern Root(zext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. @@ -17554,11 +17572,14 @@ static std::optional canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG, - Subtarget); + if (LHS.SupportsZExt) + return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } -/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS)) +/// Check if \p Root follows a pattern Root(fpext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. @@ -17566,20 +17587,11 @@ static std::optional canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG, - Subtarget); -} - -/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) -/// -/// \returns std::nullopt if the pattern doesn't match or a CombineResult that -/// can be used to apply the pattern. 
-static std::optional -canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, - Subtarget); + if (LHS.SupportsFPExt) + return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) @@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case RISCVISD::VFNMSUB_VL: Strategies.push_back(canFoldToVWWithSameExtension); if (Root->getOpcode() == RISCVISD::VFMADD_VL) - Strategies.push_back(canFoldToVWWithBF16EXT); + Strategies.push_back(canFoldToVWWithSameExtBF16); break; case ISD::MUL: case RISCVISD::MUL_VL: @@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case ISD::SHL: case RISCVISD::SHL_VL: // shl -> vwsll - Strategies.push_back(canFoldToVWWithZEXT); + Strategies.push_back(canFoldToVWWithSameExtZEXT); break; case RISCVISD::VWADD_W_VL: case RISCVISD::VWSUB_W_VL: diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll index 227a428831b60..ea4add2da5ebc 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll @@ -58,3 +58,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) { %i = or <2 x i16> %h, %g ret <2 x i16> %i } + +; Make sure we have a vsext.vf2 and a vwaddu.vx.
+define <4 x i32> @pr159152(<4 x i8> %x) { +; NO_FOLDING-LABEL: pr159152: +; NO_FOLDING: # %bb.0: +; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; NO_FOLDING-NEXT: vsext.vf2 v9, v8 +; NO_FOLDING-NEXT: li a0, 9 +; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0 +; NO_FOLDING-NEXT: ret +; +; FOLDING-LABEL: pr159152: +; FOLDING: # %bb.0: +; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; FOLDING-NEXT: vsext.vf2 v9, v8 +; FOLDING-NEXT: li a0, 9 +; FOLDING-NEXT: vwaddu.vx v8, v9, a0 +; FOLDING-NEXT: ret + %a = sext <4 x i8> %x to <4 x i16> + %b = zext <4 x i16> %a to <4 x i32> + %c = add <4 x i32> %b, + ret <4 x i32> %c +}