6363#include <bitset>
6464#include <cctype>
6565#include <numeric>
66+ #include <tuple>
6667using namespace llvm;
6768
6869#define DEBUG_TYPE "x86-isel"
@@ -44745,31 +44746,59 @@ bool X86TargetLowering::isSplatValueForTargetNode(SDValue Op,
4474544746
4474644747// Helper to peek through bitops/trunc/setcc to determine size of source vector.
4474744748// Allows combineBitcastvxi1 to determine what size vector generated a <X x i1>.
44748- static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size,
44749- bool AllowTruncate) {
44749+ static bool
44750+ checkBitcastSrcVectorSize(SDValue Src, unsigned Size, bool AllowTruncate,
44751+ std::map<std::tuple<SDValue, unsigned, bool>, bool>
44752+ &BitcastSrcVectorSizeMap) {
44753+ auto Tp = std::make_tuple(Src, Size, AllowTruncate);
44754+ if (BitcastSrcVectorSizeMap.count(Tp))
44755+ return BitcastSrcVectorSizeMap[Tp];
4475044756 switch (Src.getOpcode()) {
4475144757 case ISD::TRUNCATE:
44752- if (!AllowTruncate)
44758+ if (!AllowTruncate) {
44759+ BitcastSrcVectorSizeMap[Tp] = false;
4475344760 return false;
44761+ }
4475444762 [[fallthrough]];
44755- case ISD::SETCC:
44756- return Src.getOperand(0).getValueSizeInBits() == Size;
44757- case ISD::FREEZE:
44758- return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate);
44763+ case ISD::SETCC: {
44764+ auto Ret = Src.getOperand(0).getValueSizeInBits() == Size;
44765+ BitcastSrcVectorSizeMap[Tp] = Ret;
44766+ return Ret;
44767+ }
44768+ case ISD::FREEZE: {
44769+ auto Ret = checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate,
44770+ BitcastSrcVectorSizeMap);
44771+ BitcastSrcVectorSizeMap[Tp] = Ret;
44772+ return Ret;
44773+ }
4475944774 case ISD::AND:
4476044775 case ISD::XOR:
44761- case ISD::OR:
44762- return checkBitcastSrcVectorSize(Src.getOperand(0), Size, AllowTruncate) &&
44763- checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate);
44776+ case ISD::OR: {
44777+ auto Ret1 = checkBitcastSrcVectorSize(
44778+ Src.getOperand(0), Size, AllowTruncate, BitcastSrcVectorSizeMap);
44779+ auto Ret2 = checkBitcastSrcVectorSize(
44780+ Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
44781+ BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2;
44782+ return Ret1 && Ret2;
44783+ }
4476444784 case ISD::SELECT:
44765- case ISD::VSELECT:
44766- return Src.getOperand(0).getScalarValueSizeInBits() == 1 &&
44767- checkBitcastSrcVectorSize(Src.getOperand(1), Size, AllowTruncate) &&
44768- checkBitcastSrcVectorSize(Src.getOperand(2), Size, AllowTruncate);
44769- case ISD::BUILD_VECTOR:
44770- return ISD::isBuildVectorAllZeros(Src.getNode()) ||
44771- ISD::isBuildVectorAllOnes(Src.getNode());
44785+ case ISD::VSELECT: {
44786+ auto Ret1 = checkBitcastSrcVectorSize(
44787+ Src.getOperand(1), Size, AllowTruncate, BitcastSrcVectorSizeMap);
44788+ auto Ret2 = checkBitcastSrcVectorSize(
44789+ Src.getOperand(2), Size, AllowTruncate, BitcastSrcVectorSizeMap);
44790+ auto Ret3 = Src.getOperand(0).getScalarValueSizeInBits() == 1;
44791+ BitcastSrcVectorSizeMap[Tp] = Ret1 && Ret2 && Ret3;
44792+ return Ret1 && Ret2 && Ret3;
44793+ }
44794+ case ISD::BUILD_VECTOR: {
44795+ auto Ret = ISD::isBuildVectorAllZeros(Src.getNode()) ||
44796+ ISD::isBuildVectorAllOnes(Src.getNode());
44797+ BitcastSrcVectorSizeMap[Tp] = Ret;
44798+ return Ret;
4477244799 }
44800+ }
44801+ BitcastSrcVectorSizeMap[Tp] = false;
4477344802 return false;
4477444803}
4477544804
@@ -44925,6 +44954,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
4492544954 // (v16i8 shuffle <0,2,4,6,8,10,12,14,u,u,...,u> (v16i8 bitcast t0), undef)
4492644955 MVT SExtVT;
4492744956 bool PropagateSExt = false;
44957+ std::map<std::tuple<SDValue, unsigned, bool>, bool> BitcastSrcVectorSizeMap;
4492844958 switch (SrcVT.getSimpleVT().SimpleTy) {
4492944959 default:
4493044960 return SDValue();
@@ -44936,7 +44966,8 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
4493644966 // For cases such as (i4 bitcast (v4i1 setcc v4i64 v1, v2))
4493744967 // sign-extend to a 256-bit operation to avoid truncation.
4493844968 if (Subtarget.hasAVX() &&
44939- checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2())) {
44969+ checkBitcastSrcVectorSize(Src, 256, Subtarget.hasAVX2(),
44970+ BitcastSrcVectorSizeMap)) {
4494044971 SExtVT = MVT::v4i64;
4494144972 PropagateSExt = true;
4494244973 }
@@ -44948,8 +44979,9 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
4494844979 // If the setcc operand is 128-bit, prefer sign-extending to 128-bit over
4494944980 // 256-bit because the shuffle is cheaper than sign extending the result of
4495044981 // the compare.
44951- if (Subtarget.hasAVX() && (checkBitcastSrcVectorSize(Src, 256, true) ||
44952- checkBitcastSrcVectorSize(Src, 512, true))) {
44982+ if (Subtarget.hasAVX() &&
44983+ (checkBitcastSrcVectorSize(Src, 256, true, BitcastSrcVectorSizeMap) ||
44984+ checkBitcastSrcVectorSize(Src, 512, true, BitcastSrcVectorSizeMap))) {
4495344985 SExtVT = MVT::v8i32;
4495444986 PropagateSExt = true;
4495544987 }
@@ -44974,7 +45006,7 @@ static SDValue combineBitcastvxi1(SelectionDAG &DAG, EVT VT, SDValue Src,
4497445006 break;
4497545007 }
4497645008 // Split if this is a <64 x i8> comparison result.
44977- if (checkBitcastSrcVectorSize(Src, 512, false)) {
45009+ if (checkBitcastSrcVectorSize(Src, 512, false, BitcastSrcVectorSizeMap )) {
4497845010 SExtVT = MVT::v64i8;
4497945011 break;
4498045012 }
0 commit comments