From 4574283c27179ae0cb5b1ebbf7e7053551a7a79b Mon Sep 17 00:00:00 2001 From: "Yang, Haonan" Date: Fri, 7 Mar 2025 04:23:08 +0100 Subject: [PATCH 1/2] [X86] Fix duplicated compute in recursive search. --- llvm/lib/Target/X86/X86ISelLowering.cpp | 74 ++++++++++++++++++------- 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index deab638b7e546..1e58222f9bea1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -63,6 +63,7 @@ #include #include #include +#include using namespace llvm; #define DEBUG_TYPE "x86-isel" @@ -44745,31 +44746,59 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op, // Helper to peek through bitops/trunc/setcc to determine size of source vector. // Allows combineBitcastvxi1 to determine what size vector generated a . -static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size, - bool AllowTruncate) { +static bool +checkBitcastSrcVectorSize(SDValue Src, unsigned Size, bool AllowTruncate, + std::map, bool> + &BitcastSrcVectorSizeMap) { + auto Tp = std::make_tuple(Src, Size, AllowTruncate); + if (BitcastSrcVectorSizeMap.count(Tp)) + return BitcastSrcVectorSizeMap[Tp]; switch (Src.getOpcode()) { case ISD::TRUNCATE: - if (!AllowTruncate) + if (!AllowTruncate) { + BitcastSrcVectorSizeMap[Tp] = false; return false; + } [[fallthrough]]; - case ISD::SETCC: - return Src.getOperand(0).getValueSizeInBits() == Size; - case ISD::FREEZE: - return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate); + case ISD::SETCC: { + auto Ret = Src.getOperand(0).getValueSizeInBits() == Size; + BitcastSrcVectorSizeMap[Tp] = Ret; + return Ret; + } + case ISD::FREEZE: { + auto Ret = checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate, + BitcastSrcVectorSizeMap); + BitcastSrcVectorSizeMap[Tp] = Ret; + return Ret; + } case ISD::AND: case ISD::XOR: - case ISD::OR: - return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate) && - checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate); + case ISD::OR: { + auto Ret1 = checkBitcastSrcVectorSize( + Src.getOperand(0), Size, AllowTruncate, BitcastSrcVectorSizeMap); + auto Ret2 = checkBitcastSrcVectorSize( + Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap); + BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2; + return Ret1 && Ret2; + } case ISD::SELECT: - case ISD::VSELECT: - return Src.getOperand(0).getScalarValueSizeInBits() == 1 && - checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate) && - checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate); - case ISD::BUILD_VECTOR: - return ISD::isBuildVectorAllZeros(Src.getNode()) || - ISD::isBuildVectorAllOnes(Src.getNode()); + case ISD::VSELECT: { + auto Ret1 = checkBitcastSrcVectorSize( + Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap); + auto Ret2 = checkBitcastSrcVectorSize( + Src.getOperand(2), Size, AllowTruncate, BitcastSrcVectorSizeMap); + auto Ret3 = Src.getOperand(0).getScalarValueSizeInBits() == 1; + BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2 && Ret3; + return Ret1 && Ret2 && Ret3; + } + case ISD::BUILD_VECTOR: { + auto Ret = ISD::isBuildVectorAllZeros(Src.getNode()) || + ISD::isBuildVectorAllOnes(Src.getNode()); + BitcastSrcVectorSizeMap[Tp] = Ret; + return Ret; } + } + BitcastSrcVectorSizeMap[Tp] = false; return false; } @@ -44925,6 +44954,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef) MVT SExtVT; bool PropagateSExt = false; + std::map, bool> BitcastSrcVectorSizeMap; switch (SrcVT.getSimpleVT().SimpleTy) { default: return SDValue(); @@ -44936,7 +44966,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2)) // sign-extend to a 256-bit operation to avoid truncation. if (Subtarget.hasAVX() && - checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2())) { + checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(), + BitcastSrcVectorSizeMap)) { SExtVT = MVT::v4i64; PropagateSExt = true; } @@ -44948,8 +44979,9 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over // 256-bit because the shuffle is cheaper than sign extending the result of // the compare. - if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true) || - checkBitcastSrcVectorSize(Src, 512, true))) { + if (Subtarget.hasAVX() && + (checkBitcastSrcVectorSize(Src, 256, true, BitcastSrcVectorSizeMap) || + checkBitcastSrcVectorSize(Src, 512, true, BitcastSrcVectorSizeMap))) { SExtVT = MVT::v8i32; PropagateSExt = true; } @@ -44974,7 +45006,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, break; } // Split if this is a <64 x i8> comparison result. - if (checkBitcastSrcVectorSize(Src, 512, false)) { + if (checkBitcastSrcVectorSize(Src, 512, false, BitcastSrcVectorSizeMap)) { SExtVT = MVT::v64i8; break; } From b5aeeeebd53ff266f107e136f2ad2a837338440a Mon Sep 17 00:00:00 2001 From: "Yang, Haonan" Date: Mon, 10 Mar 2025 03:27:07 +0100 Subject: [PATCH 2/2] Add a recursive depth limit. --- llvm/lib/Target/X86/X86ISelLowering.cpp | 82 +++++++++---------------- 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 1e58222f9bea1..26b8cb029c9a8 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -63,7 +63,6 @@ #include #include #include -#include using namespace llvm; #define DEBUG_TYPE "x86-isel" @@ -44746,59 +44745,39 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op, // Helper to peek through bitops/trunc/setcc to determine size of source vector. // Allows combineBitcastvxi1 to determine what size vector generated a . -static bool -checkBitcastSrcVectorSize(SDValue Src, unsigned Size, bool AllowTruncate, - std::map, bool> - &BitcastSrcVectorSizeMap) { - auto Tp = std::make_tuple(Src, Size, AllowTruncate); - if (BitcastSrcVectorSizeMap.count(Tp)) - return BitcastSrcVectorSizeMap[Tp]; +static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size, + bool AllowTruncate, unsigned Depth) { + // Limit recursion. + if (Depth >= SelectionDAG::MaxRecursionDepth) + return false; switch (Src.getOpcode()) { case ISD::TRUNCATE: - if (!AllowTruncate) { - BitcastSrcVectorSizeMap[Tp] = false; + if (!AllowTruncate) return false; - } [[fallthrough]]; - case ISD::SETCC: { - auto Ret = Src.getOperand(0).getValueSizeInBits() == Size; - BitcastSrcVectorSizeMap[Tp] = Ret; - return Ret; - } - case ISD::FREEZE: { - auto Ret = checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate, - BitcastSrcVectorSizeMap); - BitcastSrcVectorSizeMap[Tp] = Ret; - return Ret; - } + case ISD::SETCC: + return Src.getOperand(0).getValueSizeInBits() == Size; + case ISD::FREEZE: + return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate, + Depth + 1); case ISD::AND: case ISD::XOR: - case ISD::OR: { - auto Ret1 = checkBitcastSrcVectorSize( - Src.getOperand(0), Size, AllowTruncate, BitcastSrcVectorSizeMap); - auto Ret2 = checkBitcastSrcVectorSize( - Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap); - BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2; - return Ret1 && Ret2; - } + case ISD::OR: + return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate, + Depth + 1) && + checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate, + Depth + 1); case ISD::SELECT: - case ISD::VSELECT: { - auto Ret1 = checkBitcastSrcVectorSize( - Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap); - auto Ret2 = checkBitcastSrcVectorSize( - Src.getOperand(2), Size, AllowTruncate, BitcastSrcVectorSizeMap); - auto Ret3 = Src.getOperand(0).getScalarValueSizeInBits() == 1; - BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2 && Ret3; - return Ret1 && Ret2 && Ret3; - } - case ISD::BUILD_VECTOR: { - auto Ret = ISD::isBuildVectorAllZeros(Src.getNode()) || - ISD::isBuildVectorAllOnes(Src.getNode()); - BitcastSrcVectorSizeMap[Tp] = Ret; - return Ret; - } + case ISD::VSELECT: + return Src.getOperand(0).getScalarValueSizeInBits() == 1 && + checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate, + Depth + 1) && + checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate, + Depth + 1); + case ISD::BUILD_VECTOR: + return ISD::isBuildVectorAllZeros(Src.getNode()) || + ISD::isBuildVectorAllOnes(Src.getNode()); } - BitcastSrcVectorSizeMap[Tp] = false; return false; } @@ -44954,7 +44933,6 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef) MVT SExtVT; bool PropagateSExt = false; - std::map, bool> BitcastSrcVectorSizeMap; switch (SrcVT.getSimpleVT().SimpleTy) { default: return SDValue(); @@ -44966,8 +44944,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2)) // sign-extend to a 256-bit operation to avoid truncation. if (Subtarget.hasAVX() && - checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(), - BitcastSrcVectorSizeMap)) { + checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(), 0)) { SExtVT = MVT::v4i64; PropagateSExt = true; } @@ -44979,9 +44956,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over // 256-bit because the shuffle is cheaper than sign extending the result of // the compare. - if (Subtarget.hasAVX() && - (checkBitcastSrcVectorSize(Src, 256, true, BitcastSrcVectorSizeMap) || - checkBitcastSrcVectorSize(Src, 512, true, BitcastSrcVectorSizeMap))) { + if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true, 0) || + checkBitcastSrcVectorSize(Src, 512, true, 0))) { SExtVT = MVT::v8i32; PropagateSExt = true; } @@ -45006,7 +44982,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src, break; } // Split if this is a <64 x i8> comparison result. - if (checkBitcastSrcVectorSize(Src, 512, false, BitcastSrcVectorSizeMap)) { + if (checkBitcastSrcVectorSize(Src, 512, false, 0)) { SExtVT = MVT::v64i8; break; }