|
29 | 29 | #include "llvm/ADT/DenseMap.h" |
30 | 30 | #include "llvm/ADT/SmallVector.h" |
31 | 31 | #include "llvm/Analysis/TargetLibraryInfo.h" |
| 32 | +#include "llvm/Analysis/ValueTracking.h" |
32 | 33 | #include "llvm/Analysis/VectorUtils.h" |
33 | 34 | #include "llvm/CodeGen/ISDOpcodes.h" |
34 | 35 | #include "llvm/CodeGen/SelectionDAG.h" |
@@ -138,6 +139,7 @@ class VectorLegalizer { |
138 | 139 | SDValue ExpandVP_FNEG(SDNode *Node); |
139 | 140 | SDValue ExpandVP_FABS(SDNode *Node); |
140 | 141 | SDValue ExpandVP_FCOPYSIGN(SDNode *Node); |
| 142 | + SDValue ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node); |
141 | 143 | SDValue ExpandSELECT(SDNode *Node); |
142 | 144 | std::pair<SDValue, SDValue> ExpandLoad(SDNode *N); |
143 | 145 | SDValue ExpandStore(SDNode *N); |
@@ -467,6 +469,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { |
467 | 469 | case ISD::VECTOR_COMPRESS: |
468 | 470 | case ISD::SCMP: |
469 | 471 | case ISD::UCMP: |
| 472 | + case ISD::VECTOR_EXTRACT_LAST_ACTIVE: |
470 | 473 | Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0)); |
471 | 474 | break; |
472 | 475 | case ISD::SMULFIX: |
@@ -1208,6 +1211,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { |
1208 | 1211 | case ISD::VECTOR_COMPRESS: |
1209 | 1212 | Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG)); |
1210 | 1213 | return; |
| 1214 | + case ISD::VECTOR_EXTRACT_LAST_ACTIVE: |
| 1215 | + Results.push_back(ExpandVECTOR_EXTRACT_LAST_ACTIVE(Node)); |
| 1216 | + return; |
1211 | 1217 | case ISD::SCMP: |
1212 | 1218 | case ISD::UCMP: |
1213 | 1219 | Results.push_back(TLI.expandCMP(Node, DAG)); |
@@ -1719,6 +1725,80 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) { |
1719 | 1725 | return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign); |
1720 | 1726 | } |
1721 | 1727 |
|
| 1728 | +SDValue VectorLegalizer::ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node) { |
| 1729 | + SDLoc DL(Node); |
| 1730 | + SDValue Data = Node->getOperand(0); |
| 1731 | + SDValue Mask = Node->getOperand(1); |
| 1732 | + SDValue PassThru = Node->getOperand(2); |
| 1733 | + |
| 1734 | + EVT DataVT = Data.getValueType(); |
| 1735 | + EVT ScalarVT = PassThru.getValueType(); |
| 1736 | + EVT BoolVT = Mask.getValueType().getScalarType(); |
| 1737 | + |
| 1738 | + // Find a suitable type for a stepvector. |
| 1739 | + ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value. |
| 1740 | + if (DataVT.isScalableVector()) |
| 1741 | + VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64); |
| 1742 | + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); |
| 1743 | + unsigned EltWidth = TLI.getBitWidthForCttzElements( |
| 1744 | + ScalarVT.getTypeForEVT(*DAG.getContext()), DataVT.getVectorElementCount(), |
| 1745 | + /*ZeroIsPoison=*/true, &VScaleRange); |
| 1746 | + |
| 1747 | + // HACK: If the target selects a VT that's too wide based on the legal types |
| 1748 | + // for a vecreduce_umax, if will force expansion of the node -- which |
| 1749 | + // doesn't work on scalable vectors... |
| 1750 | + // Is there another method we could use to get a smaller VT instead |
| 1751 | + // of just capping to 32b? |
| 1752 | + EVT StepVT = MVT::getIntegerVT(std::min(EltWidth, 32u)); |
| 1753 | + EVT StepVecVT = DataVT.changeVectorElementType(StepVT); |
| 1754 | + |
| 1755 | + // HACK: If the target selects a VT that's too small to form a legal vector |
| 1756 | + // type, we also run into problems trying to expand the vecreduce_umax. |
| 1757 | + // |
| 1758 | + // I think perhaps we need to revisit how getBitWidthForCttzElements |
| 1759 | + // works... |
| 1760 | + if (TLI.getTypeAction(StepVecVT.getSimpleVT()) == |
| 1761 | + TargetLowering::TypePromoteInteger) { |
| 1762 | + StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT); |
| 1763 | + StepVT = StepVecVT.getVectorElementType(); |
| 1764 | + } |
| 1765 | + |
| 1766 | + // Zero out lanes with inactive elements, then find the highest remaining |
| 1767 | + // value from the stepvector. |
| 1768 | + SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT); |
| 1769 | + SDValue StepVec = DAG.getStepVector(DL, StepVecVT); |
| 1770 | + SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes); |
| 1771 | + |
| 1772 | + // HACK: Unfortunately, LegalizeVectorOps does not recursively legalize *all* |
| 1773 | + // added nodes, just the end result nodes until it finds legal ops. |
| 1774 | + // LegalizeDAG doesn't handle VSELECT at all presently. So if we need to |
| 1775 | + // legalize a vselect then we have to do it here. |
| 1776 | + // |
| 1777 | + // We might want to change LegalizeVectorOps to walk backwards through the |
| 1778 | + // nodes like LegalizeDAG? And share VSELECT legalization code with |
| 1779 | + // LegalizeDAG? |
| 1780 | + // |
| 1781 | + // Or would that cause problems with illegal types that we might have just |
| 1782 | + // introduced? |
| 1783 | + // |
| 1784 | + // Having a legal op with illegal types marked as Legal should work, with the |
| 1785 | + // expectation being that type legalization fixes it up later. |
| 1786 | + if (TLI.getOperationAction(ISD::VSELECT, StepVecVT) == TargetLowering::Expand) |
| 1787 | + ActiveElts = LegalizeOp(ActiveElts); |
| 1788 | + |
| 1789 | + SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts); |
| 1790 | + |
| 1791 | + // Extract the corresponding lane from the data vector |
| 1792 | + EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout()); |
| 1793 | + SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, DL, ExtVT); |
| 1794 | + SDValue Extract = |
| 1795 | + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Data, Idx); |
| 1796 | + |
| 1797 | + // If all mask lanes were inactive, choose the passthru value instead. |
| 1798 | + SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, DL, BoolVT, Mask); |
| 1799 | + return DAG.getSelect(DL, ScalarVT, AnyActive, Extract, PassThru); |
| 1800 | +} |
| 1801 | + |
1722 | 1802 | void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node, |
1723 | 1803 | SmallVectorImpl<SDValue> &Results) { |
1724 | 1804 | // Attempt to expand using TargetLowering. |
|
0 commit comments