@@ -18,6 +18,7 @@
 #include "WebAssemblySubtarget.h"
 #include "WebAssemblyTargetMachine.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setOperationAction(ISD::LOAD, T, Custom);
     setOperationAction(ISD::STORE, T, Custom);
   }
+
+  // Likewise, transform zext/sext/anyext extending loads from address space 1
+  // (Wasm globals).
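+  // e.g. (i32 (zextload (i8 global))) is routed through LowerLoad below
+  // rather than being treated as legal as-is.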
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32,
+                   {MVT::i8, MVT::i16}, Custom);
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64,
+                   {MVT::i8, MVT::i16, MVT::i32}, Custom);
+
+  // Compensate for the extending loads being Custom by reimplementing the
+  // relevant combiner logic ourselves.
+  setTargetDAGCombine(ISD::AND);
+  setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
+
   if (Subtarget->hasSIMD128()) {
     for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
                    MVT::v2f64}) {
@@ -1707,6 +1721,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
   if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
     return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
 
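+  // The address may already have been lowered to a WebAssemblyISD::Wrapper
+  // around the global address; look through the wrapper as well.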
+  if (Op->getOpcode() == WebAssemblyISD::Wrapper)
+    if (const GlobalAddressSDNode *GA =
+            dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
+      return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
+
   return false;
 }
 
@@ -1764,16 +1783,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
   LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
   const SDValue &Base = LN->getBasePtr();
   const SDValue &Offset = LN->getOffset();
+  ISD::LoadExtType ExtType = LN->getExtensionType();
+  EVT ResultType = LN->getValueType(0);
 
   if (IsWebAssemblyGlobal(Base)) {
     if (!Offset->isUndef())
       report_fatal_error(
           "unexpected offset when loading from webassembly global", false);
 
-    SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
-    SDValue Ops[] = {LN->getChain(), Base};
-    return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
-                                   LN->getMemoryVT(), LN->getMemOperand());
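+    // Non-scalar (v128) results keep the previous lowering: a GLOBAL_GET
+    // typed as the load itself.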
+    if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) {
+      SDVTList Tys = DAG.getVTList(ResultType, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
+                                     LN->getMemoryVT(), LN->getMemOperand());
+    }
+
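+    // Determine the value type of the underlying global, which can differ
+    // from both the memory type of this load and its result type.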
+    EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+
+    if (auto *GA = dyn_cast<GlobalAddressSDNode>(
+            Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0)
+                                                         : Base))
+      GT = EVT::getEVT(GA->getGlobal()->getValueType());
+
+    if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
+        GT != MVT::f32 && GT != MVT::f64)
+      report_fatal_error("encountered unexpected global type for Base when "
+                         "loading from webassembly global",
+                         false);
+
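+    // An i8 or i16 global is materialized as an i32 Wasm global, so compute
+    // the promoted type that a GLOBAL_GET of it actually produces.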
+    EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT);
+
+    if (ExtType == ISD::NON_EXTLOAD) {
+      // A normal, non-extending load may try to load more or less than the
+      // underlying global holds, which is invalid. Lower it to a load of the
+      // whole global (i32 or i64), then truncate or extend as needed.
+
+      // Modify the MMO to load the full global.
+      MachineMemOperand *OldMMO = LN->getMemOperand();
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
+          WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);
+
+      if (ResultType.bitsEq(PromotedGT))
+        return GlobalGetNode;
+
+      SDValue ValRes;
+      if (ResultType.isFloatingPoint())
+        ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
+      else
+        ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);
+
+      return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
+      // Turn the unsupported extending load into a plain EXTLOAD followed by
+      // an explicit zero/sign extend inreg, just as the Expand action would.
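+      // e.g. (i32 (sextload global:i8)) becomes
+      //      (i32 (sign_extend_inreg (i32 (extload global:i8)), i8))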
+
+      SDValue Result =
+          DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
+                         LN->getMemoryVT(), LN->getMemOperand());
+      SDValue ValRes;
+      if (ExtType == ISD::SEXTLOAD)
+        ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
+                             Result, DAG.getValueType(LN->getMemoryVT()));
+      else
+        ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());
+
+      return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::EXTLOAD) {
+      // Expand the EXTLOAD into a regular LOAD of the global and, if needed,
+      // an any-extension to the result type.
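+      // e.g. (i64 (extload global:i8)) becomes
+      //      (i64 (any_extend (i32 (load global:i32))))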
+
+      EVT OldLoadType = LN->getMemoryVT();
+      EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType);
+
+      // Modify the MMO to load a whole Wasm "register"'s worth.
+      MachineMemOperand *OldMMO = LN->getMemOperand();
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDValue Result =
+          DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);
+
+      if (NewLoadType != ResultType) {
+        SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
+        return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+      }
+
+      return Result;
+    }
+
+    report_fatal_error(
+        "encountered unexpected ExtType when loading from webassembly global",
+        false);
   }
 
   if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
@@ -3637,6 +3755,184 @@ static SDValue performMulCombine(SDNode *N,
   }
 }
 
+static SDValue performANDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitAND(SDNode *N). We have to
+  // duplicate the fold here because the original combiner doesn't perform it
+  // when ZEXTLOAD has Custom lowering.
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDLoc DL(N);
+
+  // fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
+  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
+  // already be zero by virtue of the width of the base type of the load.
+  //
+  // The 'X' node here can either be nothing or an extract_vector_elt, to
+  // catch more cases.
+  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
+       N0.getOperand(0).getOpcode() == ISD::LOAD &&
+       N0.getOperand(0).getResNo() == 0) ||
+      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
+    auto *Load =
+        cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
+
+    // Get the constant (if applicable) the zero'th operand is being ANDed
+    // with. This can be a pure constant or a vector splat, in which case we
+    // treat the vector as a scalar and use the splat value.
+    APInt Constant = APInt::getZero(1);
+    if (const ConstantSDNode *C = isConstOrConstSplat(
+            N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
+      Constant = C->getAPIntValue();
+    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
+      unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
+      APInt SplatValue, SplatUndef;
+      unsigned SplatBitSize;
+      bool HasAnyUndefs;
+      // Endianness should not matter here. Code below makes sure that we only
+      // use the result if the SplatBitSize is a multiple of the vector element
+      // size. And after that we AND all element sized parts of the splat
+      // together. So the end result should be the same regardless of in which
+      // order we do those operations.
+      const bool IsBigEndian = false;
+      bool IsSplat =
+          Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                                  HasAnyUndefs, EltBitWidth, IsBigEndian);
+
+      // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
+      // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
+      if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
+        // Undef bits can contribute to a possible optimisation if set, so
+        // set them.
+        SplatValue |= SplatUndef;
+
+        // The splat value may be something like "0x00FFFFFF", which means 0
+        // for the first vector value and FF for the rest, repeating. We need
+        // a mask that will apply equally to all members of the vector, so AND
+        // all the lanes of the constant together.
+        Constant = APInt::getAllOnes(EltBitWidth);
+        for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
+          Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
+      }
+    }
+
+    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
+    // actually legal and isn't going to get expanded, else this is a false
+    // optimisation.
+
+    /*bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
+                                                      Load->getValueType(0),
+                                                      Load->getMemoryVT());*/
+    // MODIFIED: this is the one difference in the logic; we allow the ZEXT
+    // combine only in address space 0, where it's legal.
+    bool CanZextLoadProfitably = Load->getAddressSpace() == 0;
+
+    // Resize the constant to the same size as the original memory access
+    // before extension. If it is still the AllOnesValue then this AND is
+    // completely unneeded.
+    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
+
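+    // The AND is redundant for a NON_EXTLOAD or ZEXTLOAD (the masked-off high
+    // bits are already zero); for an EXTLOAD it can be folded only if the
+    // EXTLOAD can profitably become a ZEXTLOAD.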
+    bool B;
+    switch (Load->getExtensionType()) {
+    default:
+      B = false;
+      break;
+    case ISD::EXTLOAD:
+      B = CanZextLoadProfitably;
+      break;
+    case ISD::ZEXTLOAD:
+    case ISD::NON_EXTLOAD:
+      B = true;
+      break;
+    }
+
+    if (B && Constant.isAllOnes()) {
+      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
+      // preserve semantics once we get rid of the AND.
+      SDValue NewLoad(Load, 0);
+
+      // Fold the AND away. NewLoad may get replaced immediately.
+      DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
+
+      if (Load->getExtensionType() == ISD::EXTLOAD) {
+        NewLoad = DCI.DAG.getLoad(
+            Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0),
+            SDLoc(Load), Load->getChain(), Load->getBasePtr(),
+            Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand());
+        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
+        if (Load->getNumValues() == 3) {
+          // PRE/POST_INC loads have 3 values.
+          SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1),
+                          NewLoad.getValue(2)};
+          DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true);
+        } else {
+          DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
+        }
+      }
+
+      return SDValue(N, 0); // Return N so it doesn't get rechecked!
+    }
+  }
+  return SDValue();
+}
+
+static SDValue
+performSIGN_EXTEND_INREGCombine(SDNode *N,
+                                TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N).
+  // We have to duplicate the folds here because the original combiner doesn't
+  // perform them when SEXTLOAD has Custom lowering.
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
+  SDLoc DL(N);
+
+  // fold (sext_inreg (extload x)) -> (sextload x)
+  // If sextload is not supported by target, we can only do the combine when
+  // load has one use. Doing otherwise can block folding the extload with other
+  // extends that the target does support.
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0.
+  if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() &&
+        N0.hasOneUse()) ||
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    DCI.AddToWorklist(ExtLoad.getNode());
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0.
+  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) ||
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  return SDValue();
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -3672,5 +3968,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI);
+  case ISD::AND:
+    return performANDCombine(N, DCI);
+  case ISD::SIGN_EXTEND_INREG:
+    return performSIGN_EXTEND_INREGCombine(N, DCI);
   }
 }