Skip to content

Commit 584c522

Browse files
Fix lowering of loads (and extending loads) from addrspace(1) globals
1 parent d5aaf83 commit 584c522

File tree

3 files changed

+337
-30
lines changed

3 files changed

+337
-30
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 304 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "WebAssemblySubtarget.h"
1919
#include "WebAssemblyTargetMachine.h"
2020
#include "WebAssemblyUtilities.h"
21+
#include "llvm/ADT/ArrayRef.h"
2122
#include "llvm/CodeGen/CallingConvLower.h"
2223
#include "llvm/CodeGen/MachineFrameInfo.h"
2324
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
9192
setOperationAction(ISD::LOAD, T, Custom);
9293
setOperationAction(ISD::STORE, T, Custom);
9394
}
95+
96+
// Likewise, transform zext/sext/anyext extending loads from address space 1
97+
// (WASM globals)
98+
setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32,
99+
{MVT::i8, MVT::i16}, Custom);
100+
setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64,
101+
{MVT::i8, MVT::i16, MVT::i32}, Custom);
102+
103+
// Compensate for the EXTLOADs being custom by reimplementing some combiner
104+
// logic
105+
setTargetDAGCombine(ISD::AND);
106+
setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
107+
94108
if (Subtarget->hasSIMD128()) {
95109
for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
96110
MVT::v2f64}) {
@@ -1707,6 +1721,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
17071721
if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
17081722
return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
17091723

1724+
if (Op->getOpcode() == WebAssemblyISD::Wrapper)
1725+
if (const GlobalAddressSDNode *GA =
1726+
dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
1727+
return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
1728+
17101729
return false;
17111730
}
17121731

@@ -1764,16 +1783,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
17641783
LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
17651784
const SDValue &Base = LN->getBasePtr();
17661785
const SDValue &Offset = LN->getOffset();
1786+
ISD::LoadExtType ExtType = LN->getExtensionType();
1787+
EVT ResultType = LN->getValueType(0);
17671788

17681789
if (IsWebAssemblyGlobal(Base)) {
17691790
if (!Offset->isUndef())
17701791
report_fatal_error(
17711792
"unexpected offset when loading from webassembly global", false);
17721793

1773-
SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
1774-
SDValue Ops[] = {LN->getChain(), Base};
1775-
return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
1776-
LN->getMemoryVT(), LN->getMemOperand());
1794+
if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) {
1795+
SDVTList Tys = DAG.getVTList(ResultType, MVT::Other);
1796+
SDValue Ops[] = {LN->getChain(), Base};
1797+
SDValue GlobalGetNode =
1798+
DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
1799+
LN->getMemoryVT(), LN->getMemOperand());
1800+
return GlobalGetNode;
1801+
}
1802+
1803+
EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;
1804+
1805+
if (auto *GA = dyn_cast<GlobalAddressSDNode>(
1806+
Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0)
1807+
: Base))
1808+
GT = EVT::getEVT(GA->getGlobal()->getValueType());
1809+
1810+
if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
1811+
GT != MVT::f32 && GT != MVT::f64)
1812+
report_fatal_error("encountered unexpected global type for Base when "
1813+
"loading from webassembly global",
1814+
false);
1815+
1816+
EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT);
1817+
1818+
if (ExtType == ISD::NON_EXTLOAD) {
1819+
// A normal, non-extending load may try to load more or less than the
1820+
// underlying global, which is invalid. We lower this to a load of the
1821+
// global (i32 or i64) then truncate or extend as needed
1822+
1823+
// Modify the MMO to load the full global
1824+
MachineMemOperand *OldMMO = LN->getMemOperand();
1825+
MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
1826+
OldMMO->getPointerInfo(), OldMMO->getFlags(),
1827+
LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
1828+
OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
1829+
OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
1830+
1831+
SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
1832+
SDValue Ops[] = {LN->getChain(), Base};
1833+
SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
1834+
WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);
1835+
1836+
if (ResultType.bitsEq(PromotedGT)) {
1837+
return GlobalGetNode;
1838+
}
1839+
1840+
SDValue ValRes;
1841+
if (ResultType.isFloatingPoint())
1842+
ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
1843+
else
1844+
ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);
1845+
1846+
return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL);
1847+
}
1848+
1849+
if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
1850+
// Turn the unsupported load into an EXTLOAD followed by an
1851+
// explicit zero/sign extend inreg. Same as Expand
1852+
1853+
SDValue Result =
1854+
DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
1855+
LN->getMemoryVT(), LN->getMemOperand());
1856+
SDValue ValRes;
1857+
if (ExtType == ISD::SEXTLOAD)
1858+
ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
1859+
Result, DAG.getValueType(LN->getMemoryVT()));
1860+
else
1861+
ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());
1862+
1863+
return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
1864+
}
1865+
1866+
if (ExtType == ISD::EXTLOAD) {
1867+
// Expand the EXTLOAD into a regular LOAD of the global, and if
1868+
// needed, a zero-extension
1869+
1870+
EVT OldLoadType = LN->getMemoryVT();
1871+
EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType);
1872+
1873+
// Modify the MMO to load a whole WASM "register"'s worth
1874+
MachineMemOperand *OldMMO = LN->getMemOperand();
1875+
MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
1876+
OldMMO->getPointerInfo(), OldMMO->getFlags(),
1877+
LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
1878+
OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
1879+
OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
1880+
1881+
SDValue Result =
1882+
DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);
1883+
1884+
if (NewLoadType != ResultType) {
1885+
SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
1886+
return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
1887+
}
1888+
1889+
return Result;
1890+
}
1891+
1892+
report_fatal_error(
1893+
"encountered unexpected ExtType when loading from webassembly global",
1894+
false);
17771895
}
17781896

17791897
if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
@@ -3637,6 +3755,184 @@ static SDValue performMulCombine(SDNode *N,
36373755
}
36383756
}
36393757

3758+
/// Fold away an AND whose mask only clears bits that an (ext)load already
/// guarantees are zero, converting an EXTLOAD into a ZEXTLOAD where needed.
///
/// Copied and modified from DAGCombiner::visitAND(SDNode *N). We have to do
/// this because the original combiner doesn't work when ZEXTLOAD has custom
/// lowering: it consults TLI.isLoadExtLegal, which reports Custom extloads as
/// unsupported and therefore never performs this fold.
///
/// \returns SDValue(N, 0) when the fold fired (so N is not rechecked), or an
/// empty SDValue when nothing was done.
static SDValue performANDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
  // already be zero by virtue of the width of the base type of the load.
  //
  // the 'X' node here can either be nothing or an extract_vector_elt to catch
  // more cases.
  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
       N0.getOperand(0).getOpcode() == ISD::LOAD &&
       N0.getOperand(0).getResNo() == 0) ||
      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
    auto *Load =
        cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));

    // Get the constant (if applicable) the zero'th operand is being ANDed with.
    // This can be a pure constant or a vector splat, in which case we treat the
    // vector as a scalar and use the splat value.
    APInt Constant = APInt::getZero(1);
    if (const ConstantSDNode *C = isConstOrConstSplat(
            N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
      Constant = C->getAPIntValue();
    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
      unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
      APInt SplatValue, SplatUndef;
      unsigned SplatBitSize;
      bool HasAnyUndefs;
      // Endianness should not matter here. Code below makes sure that we only
      // use the result if the SplatBitSize is a multiple of the vector element
      // size. And after that we AND all element sized parts of the splat
      // together. So the end result should be the same regardless of in which
      // order we do those operations.
      const bool IsBigEndian = false;
      bool IsSplat =
          Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
                                  HasAnyUndefs, EltBitWidth, IsBigEndian);

      // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
      // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
      if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
        // Undef bits can contribute to a possible optimisation if set, so
        // set them.
        SplatValue |= SplatUndef;

        // The splat value may be something like "0x00FFFFFF", which means 0 for
        // the first vector value and FF for the rest, repeating. We need a mask
        // that will apply equally to all members of the vector, so AND all the
        // lanes of the constant together.
        Constant = APInt::getAllOnes(EltBitWidth);
        for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
          Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
      }
    }

    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
    // actually legal and isn't going to get expanded, else this is a false
    // optimisation.
    //
    // MODIFIED: this is the one difference in the logic from upstream, which
    // asks TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0),
    // Load->getMemoryVT()); our ZEXTLOADs are Custom, so that would always
    // answer false. We allow the ZEXT conversion only in address space 0,
    // where it's legal.
    bool CanZextLoadProfitably = Load->getAddressSpace() == 0;

    // Resize the constant to the same size as the original memory access before
    // extension. If it is still the AllOnesValue then this AND is completely
    // unneeded.
    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

    // Whether the AND can be removed for this kind of load (possibly after
    // converting an EXTLOAD to a ZEXTLOAD).
    bool CanFoldAway;
    switch (Load->getExtensionType()) {
    default:
      CanFoldAway = false;
      break;
    case ISD::EXTLOAD:
      CanFoldAway = CanZextLoadProfitably;
      break;
    case ISD::ZEXTLOAD:
    case ISD::NON_EXTLOAD:
      CanFoldAway = true;
      break;
    }

    if (CanFoldAway && Constant.isAllOnes()) {
      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
      // preserve semantics once we get rid of the AND.
      SDValue NewLoad(Load, 0);

      // Fold the AND away. NewLoad may get replaced immediately.
      DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

      if (Load->getExtensionType() == ISD::EXTLOAD) {
        NewLoad = DCI.DAG.getLoad(
            Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0),
            SDLoc(Load), Load->getChain(), Load->getBasePtr(),
            Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand());
        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
        if (Load->getNumValues() == 3) {
          // PRE/POST_INC loads have 3 values.
          SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1),
                          NewLoad.getValue(2)};
          DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true);
        } else {
          DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
        }
      }

      return SDValue(N, 0); // Return N so it doesn't get rechecked!
    }
  }
  return SDValue();
}
3880+
3881+
/// Fold a SIGN_EXTEND_INREG over an extending/zero-extending load into a
/// single SEXTLOAD.
///
/// Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N).
/// We have to do this because the original combiner doesn't work when SEXTLOAD
/// has custom lowering: it consults TLI.isLoadExtLegal, which reports Custom
/// extloads as unsupported, so the upstream fold never fires.
///
/// \returns SDValue(N, 0) when a fold fired (so N is not rechecked), or an
/// empty SDValue when nothing was done.
static SDValue
performSIGN_EXTEND_INREGCombine(SDNode *N,
                                TargetLowering::DAGCombinerInfo &DCI) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  // VT is the full result type; ExtVT is the (narrower) type being
  // sign-extended from, carried by the VTSDNode operand.
  EVT VT = N->getValueType(0);
  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
  SDLoc DL(N);

  // fold (sext_inreg (extload x)) -> (sextload x)
  // If sextload is not supported by target, we can only do the combine when
  // load has one use. Doing otherwise can block folding the extload with other
  // extends that the target does support.

  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
  // cast<LoadSDNode>(N0)->getAddressSpace() == 0)
  if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() &&
        N0.hasOneUse()) ||
       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
    auto *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad =
        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
    // Replace both the sext_inreg node and the load (value + chain) with the
    // new SEXTLOAD, then queue it for further combining.
    DCI.CombineTo(N, ExtLoad);
    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    DCI.AddToWorklist(ExtLoad.getNode());
    return SDValue(N, 0); // Return N so it doesn't get rechecked!
  }

  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use

  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
  // cast<LoadSDNode>(N0)->getAddressSpace() == 0)
  //
  // NOTE(review): the upstream combiner joins the pre-legalize clause and the
  // isLoadExtLegal clause with '||'; here the address-space test is joined
  // with '&&', so this zextload fold only fires before legalization AND in
  // address space 0 (the extload fold above uses '||'). Confirm whether the
  // stricter '&&' is intentional (e.g. to avoid re-creating custom-lowered
  // SEXTLOADs) or a transcription slip.
  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
      N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) &&
       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
    auto *LN0 = cast<LoadSDNode>(N0);
    SDValue ExtLoad =
        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
    // Replace both the sext_inreg node and the load (value + chain); no
    // AddToWorklist here, matching the upstream combiner's second fold.
    DCI.CombineTo(N, ExtLoad);
    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
    return SDValue(N, 0); // Return N so it doesn't get rechecked!
  }

  return SDValue();
}
3935+
36403936
SDValue
36413937
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36423938
DAGCombinerInfo &DCI) const {
@@ -3672,5 +3968,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36723968
}
36733969
case ISD::MUL:
36743970
return performMulCombine(N, DCI);
3971+
case ISD::AND:
3972+
return performANDCombine(N, DCI);
3973+
case ISD::SIGN_EXTEND_INREG:
3974+
return performSIGN_EXTEND_INREGCombine(N, DCI);
36753975
}
36763976
}

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
8989
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
9090
bool isVarArg,
9191
const SmallVectorImpl<ISD::OutputArg> &Outs,
92-
LLVMContext &Context,
93-
const Type *RetTy) const override;
92+
LLVMContext &Context, const Type *RetTy) const override;
9493
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
9594
const SmallVectorImpl<ISD::OutputArg> &Outs,
9695
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,

0 commit comments

Comments
 (0)