diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index e2ae31c86bc48..754c29b6ba868 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -191,6 +191,7 @@ class SelectionDAGLegalize { SDValue ExpandExtractFromVectorThroughStack(SDValue Op); SDValue ExpandInsertToVectorThroughStack(SDValue Op); SDValue ExpandVectorBuildThroughStack(SDNode* Node); + SDValue ExpandConcatVectors(SDNode *Node); SDValue ExpandConstantFP(ConstantFPSDNode *CFP, bool UseCP); SDValue ExpandConstant(ConstantSDNode *CP); @@ -1525,6 +1526,27 @@ SDValue SelectionDAGLegalize::ExpandInsertToVectorThroughStack(SDValue Op) { BaseVecAlignment); } +SDValue SelectionDAGLegalize::ExpandConcatVectors(SDNode *Node) { + assert(Node->getOpcode() == ISD::CONCAT_VECTORS && "Unexpected opcode!"); + SDLoc DL(Node); + SmallVector Ops; + unsigned NumOperands = Node->getNumOperands(); + MVT VectorIdxType = TLI.getVectorIdxTy(DAG.getDataLayout()); + EVT VectorValueType = Node->getOperand(0).getValueType(); + unsigned NumSubElem = VectorValueType.getVectorNumElements(); + EVT ElementValueType = TLI.getTypeToTransformTo( + *DAG.getContext(), VectorValueType.getVectorElementType()); + for (unsigned I = 0; I < NumOperands; ++I) { + SDValue SubOp = Node->getOperand(I); + for (unsigned Idx = 0; Idx < NumSubElem; ++Idx) { + Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ElementValueType, + SubOp, + DAG.getConstant(Idx, DL, VectorIdxType))); + } + } + return DAG.getBuildVector(Node->getValueType(0), DL, Ops); +} + SDValue SelectionDAGLegalize::ExpandVectorBuildThroughStack(SDNode* Node) { assert((Node->getOpcode() == ISD::BUILD_VECTOR || Node->getOpcode() == ISD::CONCAT_VECTORS) && @@ -3383,7 +3405,12 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { Results.push_back(ExpandInsertToVectorThroughStack(SDValue(Node, 0))); break; case ISD::CONCAT_VECTORS: - Results.push_back(ExpandVectorBuildThroughStack(Node)); + if (EVT VectorValueType = Node->getOperand(0).getValueType(); + VectorValueType.isScalableVector() || + TLI.isOperationExpand(ISD::EXTRACT_VECTOR_ELT, VectorValueType)) + Results.push_back(ExpandVectorBuildThroughStack(Node)); + else + Results.push_back(ExpandConcatVectors(Node)); break; case ISD::SCALAR_TO_VECTOR: Results.push_back(ExpandSCALAR_TO_VECTOR(Node));