|
29 | 29 | #include "llvm/ADT/DenseMap.h" |
30 | 30 | #include "llvm/ADT/SmallVector.h" |
31 | 31 | #include "llvm/Analysis/TargetLibraryInfo.h" |
| 32 | +#include "llvm/Analysis/ValueTracking.h" |
32 | 33 | #include "llvm/Analysis/VectorUtils.h" |
33 | 34 | #include "llvm/CodeGen/ISDOpcodes.h" |
34 | 35 | #include "llvm/CodeGen/SelectionDAG.h" |
@@ -138,6 +139,7 @@ class VectorLegalizer { |
138 | 139 | SDValue ExpandVP_FNEG(SDNode *Node); |
139 | 140 | SDValue ExpandVP_FABS(SDNode *Node); |
140 | 141 | SDValue ExpandVP_FCOPYSIGN(SDNode *Node); |
| 142 | + SDValue ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node); |
141 | 143 | SDValue ExpandSELECT(SDNode *Node); |
142 | 144 | std::pair<SDValue, SDValue> ExpandLoad(SDNode *N); |
143 | 145 | SDValue ExpandStore(SDNode *N); |
@@ -465,6 +467,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { |
465 | 467 | case ISD::VECTOR_COMPRESS: |
466 | 468 | case ISD::SCMP: |
467 | 469 | case ISD::UCMP: |
| 470 | + case ISD::VECTOR_EXTRACT_LAST_ACTIVE: |
468 | 471 | Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0)); |
469 | 472 | break; |
470 | 473 | case ISD::SMULFIX: |
@@ -1202,6 +1205,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { |
1202 | 1205 | case ISD::VECTOR_COMPRESS: |
1203 | 1206 | Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG)); |
1204 | 1207 | return; |
| 1208 | + case ISD::VECTOR_EXTRACT_LAST_ACTIVE: |
| 1209 | + Results.push_back(ExpandVECTOR_EXTRACT_LAST_ACTIVE(Node)); |
| 1210 | + return; |
1205 | 1211 | case ISD::SCMP: |
1206 | 1212 | case ISD::UCMP: |
1207 | 1213 | Results.push_back(TLI.expandCMP(Node, DAG)); |
@@ -1713,6 +1719,61 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) { |
1713 | 1719 | return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign); |
1714 | 1720 | } |
1715 | 1721 |
|
| 1722 | +SDValue VectorLegalizer::ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node) { |
| 1723 | + dbgs() << "Expanding extract_last_active!!\n"; |
| 1724 | + SDLoc DL(Node); |
| 1725 | + SDValue Data = Node->getOperand(0); |
| 1726 | + SDValue Mask = Node->getOperand(1); |
| 1727 | + SDValue PassThru = Node->getOperand(2); |
| 1728 | + |
| 1729 | + EVT DataVT = Data.getValueType(); |
| 1730 | + EVT ScalarVT = PassThru.getValueType(); |
| 1731 | + EVT BoolVT = Mask.getValueType().getScalarType(); |
| 1732 | + |
| 1733 | + // Find a suitable type for a stepvector. |
| 1734 | + ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value. |
| 1735 | + if (DataVT.isScalableVector()) |
| 1736 | + VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64); |
| 1737 | + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); |
| 1738 | + unsigned EltWidth = TLI.getBitWidthForCttzElements( |
| 1739 | + ScalarVT.getTypeForEVT(*DAG.getContext()), DataVT.getVectorElementCount(), |
| 1740 | + /*ZeroIsPoison=*/true, &VScaleRange); |
| 1741 | + EVT StepVT = MVT::getIntegerVT(EltWidth); |
| 1742 | + EVT StepVecVT = DataVT.changeVectorElementType(StepVT); |
| 1743 | + |
| 1744 | + // Promote to a legal type if necessary. |
| 1745 | + if (TLI.getTypeAction(StepVecVT.getSimpleVT()) == |
| 1746 | + TargetLowering::TypePromoteInteger) { |
| 1747 | + StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT); |
| 1748 | + StepVT = StepVecVT.getVectorElementType(); |
| 1749 | + } |
| 1750 | + |
| 1751 | + // Zero out lanes with inactive elements, then find the highest remaining |
| 1752 | + // value from the stepvector. |
| 1753 | + SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT); |
| 1754 | + SDValue StepVec = DAG.getStepVector(DL, StepVecVT); |
| 1755 | + SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes); |
| 1756 | + // Unfortunately, VectorLegalizer does not recursively legalize all added |
| 1757 | + // nodes, just the end result nodes. LegalizeDAG doesn't handle VSELECT at |
| 1758 | + // all presently. So if we need to legalize a vselect then we have to do |
| 1759 | + // it here. |
| 1760 | + if (!TLI.isTypeLegal(StepVecVT) || |
| 1761 | + TLI.getOperationAction(ISD::VSELECT, StepVecVT) == TargetLowering::Expand) |
| 1762 | + ActiveElts = LegalizeOp(ActiveElts); |
| 1763 | + |
| 1764 | + SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts); |
| 1765 | + |
| 1766 | + // Extract the corresponding lane from the data vector |
| 1767 | + EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout()); |
| 1768 | + SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, DL, ExtVT); |
| 1769 | + SDValue Extract = |
| 1770 | + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Data, Idx); |
| 1771 | + |
| 1772 | + // If all mask lanes were inactive, choose the passthru value instead. |
| 1773 | + SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, DL, BoolVT, Mask); |
| 1774 | + return DAG.getSelect(DL, ScalarVT, AnyActive, Extract, PassThru); |
| 1775 | +} |
| 1776 | + |
1716 | 1777 | void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node, |
1717 | 1778 | SmallVectorImpl<SDValue> &Results) { |
1718 | 1779 | // Attempt to expand using TargetLowering. |
|
0 commit comments