diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index 1e83cbeac50d6..aa604ee8cb2c9 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -9,12 +9,20 @@ tablegen(LLVM WebAssemblyGenDisassemblerTables.inc -gen-disassembler) tablegen(LLVM WebAssemblyGenFastISel.inc -gen-fast-isel) tablegen(LLVM WebAssemblyGenInstrInfo.inc -gen-instr-info) tablegen(LLVM WebAssemblyGenMCCodeEmitter.inc -gen-emitter) +tablegen(LLVM WebAssemblyGenRegisterBank.inc -gen-register-bank) tablegen(LLVM WebAssemblyGenRegisterInfo.inc -gen-register-info) tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget) +set(LLVM_TARGET_DEFINITIONS WebAssemblyGISel.td) +tablegen(LLVM WebAssemblyGenGlobalISel.inc -gen-global-isel) + add_public_tablegen_target(WebAssemblyCommonTableGen) add_llvm_target(WebAssemblyCodeGen + GISel/WebAssemblyCallLowering.cpp + GISel/WebAssemblyInstructionSelector.cpp + GISel/WebAssemblyRegisterBankInfo.cpp + GISel/WebAssemblyLegalizerInfo.cpp WebAssemblyAddMissingPrototypes.cpp WebAssemblyArgumentMove.cpp WebAssemblyAsmPrinter.cpp diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp new file mode 100644 index 0000000000000..43ba7b1a983aa --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -0,0 +1,1246 @@ +//===-- WebAssemblyCallLowering.cpp - Call lowering for GlobalISel -*- C++ -*-// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the lowering of LLVM calls to machine code calls for +/// GlobalISel. +/// +//===----------------------------------------------------------------------===// + +#include "WebAssemblyCallLowering.h" +#include "GISel/WebAssemblyRegisterBankInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "Utils/WasmAddressSpaces.h" +#include "WebAssemblyISelLowering.h" +#include "WebAssemblyMachineFunctionInfo.h" +#include "WebAssemblyRegisterInfo.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyUtilities.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/CodeGen/Analysis.h" +#include "llvm/CodeGen/CallingConvLower.h" +#include "llvm/CodeGen/FunctionLoweringInfo.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/GlobalISel/Utils.h" +#include "llvm/CodeGen/LowLevelTypeUtils.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGenTypes/LowLevelType.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" + +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/MC/MCSymbolWasm.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include + +#define DEBUG_TYPE "wasm-call-lowering" + +using namespace llvm; + +// Several of the following methods are internal utilities defined in +// CodeGen/GlobalIsel/CallLowering.cpp +// TODO: Find a better solution? 
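+ +// Illustrative example (hypothetical argument): with extendOpFromFlags below, +// an 'i8 signext %x' parameter widens via G_SEXT, an 'i8 zeroext %x' via +// G_ZEXT, and an unannotated one via G_ANYEXT.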
+ +// Internal utility from CallLowering.cpp +static unsigned extendOpFromFlags(ISD::ArgFlagsTy Flags) { + if (Flags.isSExt()) + return TargetOpcode::G_SEXT; + if (Flags.isZExt()) + return TargetOpcode::G_ZEXT; + return TargetOpcode::G_ANYEXT; +} + +// Internal utility from CallLowering.cpp +/// Pack values \p SrcRegs to cover the vector type result \p DstRegs. +static MachineInstrBuilder +mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs, + ArrayRef<Register> SrcRegs) { + MachineRegisterInfo &MRI = *B.getMRI(); + LLT LLTy = MRI.getType(DstRegs[0]); + LLT PartLLT = MRI.getType(SrcRegs[0]); + + // Deal with v3s16 split into v2s16 + LLT LCMTy = getCoverTy(LLTy, PartLLT); + if (LCMTy == LLTy) { + // Common case where no padding is needed. + assert(DstRegs.size() == 1); + return B.buildConcatVectors(DstRegs[0], SrcRegs); + } + + // We need to create an unmerge to the result registers, which may require + // widening the original value. + Register UnmergeSrcReg; + if (LCMTy != PartLLT) { + assert(DstRegs.size() == 1); + return B.buildDeleteTrailingVectorElements( + DstRegs[0], B.buildMergeLikeInstr(LCMTy, SrcRegs)); + } else { + // We don't need to widen anything if we're extracting a scalar which was + // promoted to a vector e.g. s8 -> v4s8 -> s8 + assert(SrcRegs.size() == 1); + UnmergeSrcReg = SrcRegs[0]; + } + + int NumDst = LCMTy.getSizeInBits() / LLTy.getSizeInBits(); + + SmallVector<Register> PadDstRegs(NumDst); + llvm::copy(DstRegs, PadDstRegs.begin()); + + // Create the excess dead defs for the unmerge. + for (int I = DstRegs.size(); I != NumDst; ++I) + PadDstRegs[I] = MRI.createGenericVirtualRegister(LLTy); + + if (PadDstRegs.size() == 1) + return B.buildDeleteTrailingVectorElements(DstRegs[0], UnmergeSrcReg); + return B.buildUnmerge(PadDstRegs, UnmergeSrcReg); +} + +// Internal utility from CallLowering.cpp +/// Create a sequence of instructions to combine pieces split into register +/// typed values to the original IR value. \p OrigRegs contains the destination +/// value registers of type \p LLTy, and \p Regs contains the legalized pieces +/// with type \p PartLLT. This is used for incoming values (physregs to vregs). + +// Modified to account for floating-point extends/truncations +static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs, + ArrayRef<Register> Regs, LLT LLTy, LLT PartLLT, + const ISD::ArgFlagsTy Flags, + bool IsFloatingPoint) { + MachineRegisterInfo &MRI = *B.getMRI(); + + if (PartLLT == LLTy) { + // We should have avoided introducing a new virtual register, and just + // directly assigned here. + assert(OrigRegs[0] == Regs[0]); + return; + } + + if (PartLLT.getSizeInBits() == LLTy.getSizeInBits() && OrigRegs.size() == 1 && + Regs.size() == 1) { + B.buildBitcast(OrigRegs[0], Regs[0]); + return; + } + + // A vector PartLLT needs extending to LLTy's element size. + // E.g. <2 x s64> = G_SEXT <2 x s32>. + if (PartLLT.isVector() == LLTy.isVector() && + PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() && + (!PartLLT.isVector() || + PartLLT.getElementCount() == LLTy.getElementCount()) && + OrigRegs.size() == 1 && Regs.size() == 1) { + Register SrcReg = Regs[0]; + + LLT LocTy = MRI.getType(SrcReg); + + if (Flags.isSExt()) { + SrcReg = B.buildAssertSExt(LocTy, SrcReg, LLTy.getScalarSizeInBits()) + .getReg(0); + } else if (Flags.isZExt()) { + SrcReg = B.buildAssertZExt(LocTy, SrcReg, LLTy.getScalarSizeInBits()) + .getReg(0); + } + + // Sometimes pointers are passed zero extended.
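+ // Illustrative example (hypothetical types): a 32-bit p0 result passed in an +// s64 location is recovered as %t:_(s32) = G_TRUNC %loc(s64) followed by +// %dst:_(p0) = G_INTTOPTR %t(s32).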
+ LLT OrigTy = MRI.getType(OrigRegs[0]); + if (OrigTy.isPointer()) { + LLT IntPtrTy = LLT::scalar(OrigTy.getSizeInBits()); + B.buildIntToPtr(OrigRegs[0], B.buildTrunc(IntPtrTy, SrcReg)); + return; + } + + if (IsFloatingPoint) + B.buildFPTrunc(OrigRegs[0], SrcReg); + else + B.buildTrunc(OrigRegs[0], SrcReg); + return; + } + + if (!LLTy.isVector() && !PartLLT.isVector()) { + assert(OrigRegs.size() == 1); + LLT OrigTy = MRI.getType(OrigRegs[0]); + + unsigned SrcSize = PartLLT.getSizeInBits().getFixedValue() * Regs.size(); + if (SrcSize == OrigTy.getSizeInBits()) + B.buildMergeValues(OrigRegs[0], Regs); + else { + auto Widened = B.buildMergeLikeInstr(LLT::scalar(SrcSize), Regs); + + if (IsFloatingPoint) + B.buildFPTrunc(OrigRegs[0], Widened); + else + B.buildTrunc(OrigRegs[0], Widened); + } + + return; + } + + if (PartLLT.isVector()) { + assert(OrigRegs.size() == 1); + SmallVector<Register> CastRegs(Regs); + + // If PartLLT is a mismatched vector in both number of elements and element + // size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to + // have the same elt type, i.e. v4s32. + // TODO: Extend this coercion to element multiples other than just 2. + if (TypeSize::isKnownGT(PartLLT.getSizeInBits(), LLTy.getSizeInBits()) && + PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 && + Regs.size() == 1) { + LLT NewTy = PartLLT.changeElementType(LLTy.getElementType()) + .changeElementCount(PartLLT.getElementCount() * 2); + CastRegs[0] = B.buildBitcast(NewTy, Regs[0]).getReg(0); + PartLLT = NewTy; + } + + if (LLTy.getScalarType() == PartLLT.getElementType()) { + mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs); + } else { + unsigned I = 0; + LLT GCDTy = getGCDType(LLTy, PartLLT); + + // We are both splitting a vector, and bitcasting its element types. Cast + // the source pieces into the appropriate number of pieces with the result + // element type. + for (Register SrcReg : CastRegs) + CastRegs[I++] = B.buildBitcast(GCDTy, SrcReg).getReg(0); + mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs); + } + + return; + } + + assert(LLTy.isVector() && !PartLLT.isVector()); + + LLT DstEltTy = LLTy.getElementType(); + + // Pointer information was discarded. We'll need to coerce some register types + // to avoid violating type constraints. + LLT RealDstEltTy = MRI.getType(OrigRegs[0]).getElementType(); + + assert(DstEltTy.getSizeInBits() == RealDstEltTy.getSizeInBits()); + + if (DstEltTy == PartLLT) { + // Vector was trivially scalarized. + + if (RealDstEltTy.isPointer()) { + for (Register Reg : Regs) + MRI.setType(Reg, RealDstEltTy); + } + + B.buildBuildVector(OrigRegs[0], Regs); + } else if (DstEltTy.getSizeInBits() > PartLLT.getSizeInBits()) { + // Deal with vector with 64-bit elements decomposed to 32-bit + // registers. Need to create intermediate 64-bit elements. + SmallVector<Register> EltMerges; + int PartsPerElt = + divideCeil(DstEltTy.getSizeInBits(), PartLLT.getSizeInBits()); + LLT ExtendedPartTy = LLT::scalar(PartLLT.getSizeInBits() * PartsPerElt); + + for (int I = 0, NumElts = LLTy.getNumElements(); I != NumElts; ++I) { + auto Merge = + B.buildMergeLikeInstr(ExtendedPartTy, Regs.take_front(PartsPerElt)); + if (ExtendedPartTy.getSizeInBits() > RealDstEltTy.getSizeInBits()) + Merge = B.buildTrunc(RealDstEltTy, Merge); + // Fix the type in case this is really a vector of pointers.
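+ // Illustrative example (hypothetical types): a v2p0 of 64-bit pointers +// rebuilt from four s32 pieces merges each pair into an s64, which is retyped +// to p0 here before feeding the G_BUILD_VECTOR below.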
+ MRI.setType(Merge.getReg(0), RealDstEltTy); + EltMerges.push_back(Merge.getReg(0)); + Regs = Regs.drop_front(PartsPerElt); + } + + B.buildBuildVector(OrigRegs[0], EltMerges); + } else { + // Vector was split, and elements promoted to a wider type. + // FIXME: Should handle floating point promotions. + unsigned NumElts = LLTy.getNumElements(); + LLT BVType = LLT::fixed_vector(NumElts, PartLLT); + + Register BuildVec; + if (NumElts == Regs.size()) + BuildVec = B.buildBuildVector(BVType, Regs).getReg(0); + else { + // Vector elements are packed in the inputs. + // e.g. we have a <4 x s16> but 2 x s32 in regs. + assert(NumElts > Regs.size()); + LLT SrcEltTy = MRI.getType(Regs[0]); + + LLT OriginalEltTy = MRI.getType(OrigRegs[0]).getElementType(); + + // Input registers contain packed elements. + // Determine how many elements per reg. + assert((SrcEltTy.getSizeInBits() % OriginalEltTy.getSizeInBits()) == 0); + unsigned EltPerReg = + (SrcEltTy.getSizeInBits() / OriginalEltTy.getSizeInBits()); + + SmallVector<Register> BVRegs; + BVRegs.reserve(Regs.size() * EltPerReg); + for (Register R : Regs) { + auto Unmerge = B.buildUnmerge(OriginalEltTy, R); + for (unsigned K = 0; K < EltPerReg; ++K) + BVRegs.push_back(B.buildAnyExt(PartLLT, Unmerge.getReg(K)).getReg(0)); + } + + // We may have some more elements in BVRegs, e.g. if we have 2 s32 pieces + // for a <3 x s16> vector. We should have fewer than EltPerReg extra items. + if (BVRegs.size() > NumElts) { + assert((BVRegs.size() - NumElts) < EltPerReg); + BVRegs.truncate(NumElts); + } + BuildVec = B.buildBuildVector(BVType, BVRegs).getReg(0); + } + B.buildTrunc(OrigRegs[0], BuildVec); + } +} + +// Internal utility from CallLowering.cpp +/// Create a sequence of instructions to expand the value in \p SrcReg (of type +/// \p SrcTy) to the types in \p DstRegs (of type \p PartTy). \p ExtendOp should +/// contain the type of scalar value extension if necessary. +/// +/// This is used for outgoing values (vregs to physregs) +static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs, + Register SrcReg, LLT SrcTy, LLT PartTy, + unsigned ExtendOp = TargetOpcode::G_ANYEXT) { + // We could just insert a regular copy, but this is unreachable at the moment. + assert(SrcTy != PartTy && "identical part types shouldn't reach here"); + + const TypeSize PartSize = PartTy.getSizeInBits(); + + if (PartTy.isVector() == SrcTy.isVector() && + PartTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits()) { + assert(DstRegs.size() == 1); + B.buildInstr(ExtendOp, {DstRegs[0]}, {SrcReg}); + return; + } + + if (SrcTy.isVector() && !PartTy.isVector() && + TypeSize::isKnownGT(PartSize, SrcTy.getElementType().getSizeInBits())) { + // Vector was scalarized, and the elements extended. + auto UnmergeToEltTy = B.buildUnmerge(SrcTy.getElementType(), SrcReg); + for (int i = 0, e = DstRegs.size(); i != e; ++i) + B.buildAnyExt(DstRegs[i], UnmergeToEltTy.getReg(i)); + return; + } + + if (SrcTy.isVector() && PartTy.isVector() && + PartTy.getSizeInBits() == SrcTy.getSizeInBits() && + ElementCount::isKnownLT(SrcTy.getElementCount(), + PartTy.getElementCount())) { + // A coercion like: v2f32 -> v4f32 or nxv2f32 -> nxv4f32 + Register DstReg = DstRegs.front(); + B.buildPadVectorWithUndefElements(DstReg, SrcReg); + return; + } + + LLT GCDTy = getGCDType(SrcTy, PartTy); + if (GCDTy == PartTy) { + // If this is already evenly divisible, we can create a simple unmerge.
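+ // Illustrative example: SrcTy == v4s32 with PartTy == s32 gives GCDTy == +// PartTy, so the value is simply unmerged into four s32 pieces.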
+ B.buildUnmerge(DstRegs, SrcReg); + return; + } + + if (SrcTy.isVector() && !PartTy.isVector() && + SrcTy.getScalarSizeInBits() > PartTy.getSizeInBits()) { + LLT ExtTy = + LLT::vector(SrcTy.getElementCount(), + LLT::scalar(PartTy.getScalarSizeInBits() * DstRegs.size() / + SrcTy.getNumElements())); + auto Ext = B.buildAnyExt(ExtTy, SrcReg); + B.buildUnmerge(DstRegs, Ext); + return; + } + + MachineRegisterInfo &MRI = *B.getMRI(); + LLT DstTy = MRI.getType(DstRegs[0]); + LLT LCMTy = getCoverTy(SrcTy, PartTy); + + if (PartTy.isVector() && LCMTy == PartTy) { + assert(DstRegs.size() == 1); + B.buildPadVectorWithUndefElements(DstRegs[0], SrcReg); + return; + } + + const unsigned DstSize = DstTy.getSizeInBits(); + const unsigned SrcSize = SrcTy.getSizeInBits(); + unsigned CoveringSize = LCMTy.getSizeInBits(); + + Register UnmergeSrc = SrcReg; + + if (!LCMTy.isVector() && CoveringSize != SrcSize) { + // For scalars, it's common to be able to use a simple extension. + if (SrcTy.isScalar() && DstTy.isScalar()) { + CoveringSize = alignTo(SrcSize, DstSize); + LLT CoverTy = LLT::scalar(CoveringSize); + UnmergeSrc = B.buildInstr(ExtendOp, {CoverTy}, {SrcReg}).getReg(0); + } else { + // Widen to the common type. + // FIXME: This should respect the extend type + Register Undef = B.buildUndef(SrcTy).getReg(0); + SmallVector MergeParts(1, SrcReg); + for (unsigned Size = SrcSize; Size != CoveringSize; Size += SrcSize) + MergeParts.push_back(Undef); + UnmergeSrc = B.buildMergeLikeInstr(LCMTy, MergeParts).getReg(0); + } + } + + if (LCMTy.isVector() && CoveringSize != SrcSize) + UnmergeSrc = B.buildPadVectorWithUndefElements(LCMTy, SrcReg).getReg(0); + + B.buildUnmerge(DstRegs, UnmergeSrc); +} + +// Test whether the given calling convention is supported. +static bool callingConvSupported(CallingConv::ID CallConv) { + // We currently support the language-independent target-independent + // conventions. We don't yet have a way to annotate calls with properties like + // "cold", and we don't have any call-clobbered registers, so these are mostly + // all handled the same. 
return CallConv == CallingConv::C || CallConv == CallingConv::Fast || + CallConv == CallingConv::Cold || + CallConv == CallingConv::PreserveMost || + CallConv == CallingConv::PreserveAll || + CallConv == CallingConv::CXX_FAST_TLS || + CallConv == CallingConv::WASM_EmscriptenInvoke || + CallConv == CallingConv::Swift; +} + +static void fail(MachineIRBuilder &MIRBuilder, const char *Msg) { + MachineFunction &MF = MIRBuilder.getMF(); + MIRBuilder.getContext().diagnose( + DiagnosticInfoUnsupported(MF.getFunction(), Msg, MIRBuilder.getDL())); +} + +WebAssemblyCallLowering::WebAssemblyCallLowering( + const WebAssemblyTargetLowering &TLI) + : CallLowering(&TLI) {} + +bool WebAssemblyCallLowering::canLowerReturn(MachineFunction &MF, + CallingConv::ID CallConv, + SmallVectorImpl<BaseArgInfo> &Outs, + bool IsVarArg) const { + return WebAssembly::canLowerReturn(Outs.size(), + &MF.getSubtarget<WebAssemblySubtarget>()); +} + +bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, + const Value *Val, + ArrayRef<Register> VRegs, + FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const { + auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN); + MachineFunction &MF = MIRBuilder.getMF(); + auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>(); + auto &RBI = *Subtarget.getRegBankInfo(); + + assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) && + "Return value without a vreg"); + + if (Val && !FLI.CanLowerReturn) { + insertSRetStores(MIRBuilder, Val->getType(), VRegs, FLI.DemoteRegister); + } else if (!VRegs.empty()) { + const Function &F = MF.getFunction(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const WebAssemblyTargetLowering &TLI = *getTLI<WebAssemblyTargetLowering>(); + auto &DL = F.getDataLayout(); + LLVMContext &Ctx = Val->getType()->getContext(); + + SmallVector<EVT> SplitEVTs; + ComputeValueVTs(TLI, DL, Val->getType(), SplitEVTs); + assert(VRegs.size() == SplitEVTs.size() && + "For each split Type there should be exactly one VReg."); + + SmallVector<ArgInfo> SplitArgs; + CallingConv::ID CallConv = F.getCallingConv(); + + unsigned i = 0; + for (auto SplitEVT : SplitEVTs) { + Register CurVReg = VRegs[i]; + ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F); + + splitToValueTypes(CurArgInfo, SplitArgs, DL, CallConv); + ++i; + } + + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will hold the + // outgoing component of the larger value. These are filled by splitting + // the original vreg below.
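+ // Illustrative example (hypothetical source): returning an i128 on wasm32 +// gives NumParts == 2 with NewVT == i64, so the i128 is split into two i64 +// vregs here.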
+ for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, + Arg.Ty->isFloatingPointTy() + ? TargetOpcode::G_FPEXT + : extendOpFromFlags(Arg.Flags[0])); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + auto NewOutReg = Arg.Regs[Part]; + if (!RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI)) { + NewOutReg = MRI.createGenericVirtualRegister(NewLLT); + assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && + "Couldn't constrain brand-new register?"); + MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); + } + MIB.addUse(NewOutReg); + } + } + } + + if (SwiftErrorVReg) { + llvm_unreachable("WASM does not `supportSwiftError`, yet SwiftErrorVReg is " + "improperly valid."); + } + + MIRBuilder.insertInstr(MIB); + return true; +} + +static unsigned getWASMArgOpcode(MVT ArgType) { +#define MVT_CASE(type) \ + case MVT::type: \ + return WebAssembly::ARGUMENT_##type; + + switch (ArgType.SimpleTy) { + MVT_CASE(i32) + MVT_CASE(i64) + MVT_CASE(f32) + MVT_CASE(f64) + MVT_CASE(funcref) + MVT_CASE(externref) + MVT_CASE(exnref) + MVT_CASE(v16i8) + MVT_CASE(v8i16) + MVT_CASE(v4i32) + MVT_CASE(v2i64) + MVT_CASE(v4f32) + MVT_CASE(v2f64) + MVT_CASE(v8f16) + default: + break; + } + llvm_unreachable("Found unexpected type for WASM argument"); + +#undef MVT_CASE +} + +bool WebAssemblyCallLowering::lowerFormalArguments( + MachineIRBuilder &MIRBuilder, const Function &F, + ArrayRef<ArrayRef<Register>> VRegs, FunctionLoweringInfo &FLI) const { + + MachineFunction &MF = MIRBuilder.getMF(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + WebAssemblyFunctionInfo *MFI = MF.getInfo<WebAssemblyFunctionInfo>(); + const DataLayout &DL = F.getDataLayout(); + auto &TLI = *getTLI<WebAssemblyTargetLowering>(); + auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>(); + auto &TRI = *Subtarget.getRegisterInfo(); + auto &TII = *Subtarget.getInstrInfo(); + auto &RBI = *Subtarget.getRegBankInfo(); + + LLVMContext &Ctx = MIRBuilder.getContext(); + const CallingConv::ID CallConv = F.getCallingConv(); + + if (!callingConvSupported(F.getCallingConv())) { + fail(MIRBuilder, "WebAssembly doesn't support non-C calling conventions"); + return false; + } + + // Set up the live-in for the incoming ARGUMENTS. + MF.getRegInfo().addLiveIn(WebAssembly::ARGUMENTS); + + SmallVector<ArgInfo> SplitArgs; + + if (!FLI.CanLowerReturn) { + insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL); + } + unsigned i = 0; + + bool HasSwiftErrorArg = false; + bool HasSwiftSelfArg = false; + for (const auto &Arg : F.args()) { + ArgInfo OrigArg{VRegs[i], Arg.getType(), i}; + setArgFlags(OrigArg, i + AttributeList::FirstArgIndex, DL, F); + + HasSwiftSelfArg |= Arg.hasSwiftSelfAttr(); + HasSwiftErrorArg |= Arg.hasSwiftErrorAttr(); + if (Arg.hasInAllocaAttr()) { + fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments"); + return false; + } + if (Arg.hasNestAttr()) { + fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments"); + return false; + } + splitToValueTypes(OrigArg, SplitArgs, DL, F.getCallingConv()); + ++i; + } + + unsigned FinalArgIdx = 0; + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support.
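+ // Illustrative example (hypothetical signature): for 'i64 @f(i64 %x)' the +// loop below emits a single ARGUMENT_i64; an i128 parameter would instead +// emit two ARGUMENT_i64 instructions with consecutive argument indices.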
+ unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg. + for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT, + Arg.Flags[0], Arg.Ty->isFloatingPointTy()); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + auto ArgInst = MIRBuilder.buildInstr(getWASMArgOpcode(NewVT)) + .addDef(Arg.Regs[Part]) + .addImm(FinalArgIdx); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *ArgInst, + ArgInst->getDesc(), ArgInst->getOperand(0), 0); + MFI->addParam(NewVT); + ++FinalArgIdx; + } + } + + /**/ + + // For swiftcc, emit additional swiftself and swifterror arguments + // if there aren't. These additional arguments are also added for callee + // signature They are necessary to match callee and caller signature for + // indirect call. + auto PtrVT = TLI.getPointerTy(DL); + if (CallConv == CallingConv::Swift) { + if (!HasSwiftSelfArg) { + MFI->addParam(PtrVT); + } + if (!HasSwiftErrorArg) { + MFI->addParam(PtrVT); + } + } + + // Varargs are copied into a buffer allocated by the caller, and a pointer to + // the buffer is passed as an argument. + if (F.isVarArg()) { + auto PtrVT = TLI.getPointerTy(DL); + Register VarargVreg = MF.getRegInfo().createGenericVirtualRegister( + getLLTForType(*PointerType::get(Ctx, 0), DL)); + MFI->setVarargBufferVreg(VarargVreg); + + auto ArgInst = MIRBuilder.buildInstr(getWASMArgOpcode(PtrVT)) + .addDef(VarargVreg) + .addImm(FinalArgIdx); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *ArgInst, + ArgInst->getDesc(), ArgInst->getOperand(0), 0); + + MFI->addParam(PtrVT); + ++FinalArgIdx; + } + + // Record the number and types of arguments and results. 
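+ // Illustrative example (hypothetical signature): for 'f64 @g(i32, f32)', +// computeSignatureVTs yields Params == {i32, f32} and Results == {f64}, +// mirroring the wasm signature (i32, f32) -> (f64).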
+ SmallVector<MVT> Params; + SmallVector<MVT> Results; + computeSignatureVTs(MF.getFunction().getFunctionType(), &MF.getFunction(), + MF.getFunction(), MF.getTarget(), Params, Results); + for (MVT VT : Results) + MFI->addResult(VT); + + // TODO: Use signatures in WebAssemblyMachineFunctionInfo too and unify + // the param logic here with computeSignatureVTs + assert(MFI->getParams().size() == Params.size() && + std::equal(MFI->getParams().begin(), MFI->getParams().end(), + Params.begin())); + return true; +} + +bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const { + MachineFunction &MF = MIRBuilder.getMF(); + const DataLayout &DL = MIRBuilder.getDataLayout(); + LLVMContext &Ctx = MIRBuilder.getContext(); + const WebAssemblyTargetLowering &TLI = *getTLI<WebAssemblyTargetLowering>(); + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + const WebAssemblySubtarget &Subtarget = + MF.getSubtarget<WebAssemblySubtarget>(); + auto &TRI = *Subtarget.getRegisterInfo(); + auto &TII = *Subtarget.getInstrInfo(); + auto &RBI = *Subtarget.getRegBankInfo(); + + CallingConv::ID CallConv = Info.CallConv; + if (!callingConvSupported(CallConv)) { + fail(MIRBuilder, + "WebAssembly doesn't support language-specific or target-specific " + "calling conventions yet"); + return false; + } + + // TODO: investigate "PatchPoint" + /* + if (Info.IsPatchPoint) { + fail(MIRBuilder, "WebAssembly doesn't support patch point yet"); + return false; + } + */ + + if (Info.IsTailCall) { + Info.LoweredTailCall = true; + auto NoTail = [&](const char *Msg) { + if (Info.CB && Info.CB->isMustTailCall()) + fail(MIRBuilder, Msg); + Info.LoweredTailCall = false; + }; + + if (!Subtarget.hasTailCall()) + NoTail("WebAssembly 'tail-call' feature not enabled"); + + // Varargs calls cannot be tail calls because the buffer is on the stack + if (Info.IsVarArg) + NoTail("WebAssembly does not support varargs tail calls"); + + // Do not tail call unless caller and callee return types match + const Function &F = MF.getFunction(); + const TargetMachine &TM = TLI.getTargetMachine(); + Type *RetTy = F.getReturnType(); + SmallVector<MVT> CallerRetTys; + SmallVector<MVT> CalleeRetTys; + computeLegalValueVTs(F, TM, RetTy, CallerRetTys); + computeLegalValueVTs(F, TM, Info.OrigRet.Ty, CalleeRetTys); + bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() && + std::equal(CallerRetTys.begin(), CallerRetTys.end(), + CalleeRetTys.begin()); + if (!TypesMatch) + NoTail("WebAssembly tail call requires caller and callee return types to " + "match"); + + // If pointers to local stack values are passed, we cannot tail call + if (Info.CB) { + for (auto &Arg : Info.CB->args()) { + Value *Val = Arg.get(); + // Trace the value back through pointer operations + while (true) { + Value *Src = Val->stripPointerCastsAndAliases(); + if (auto *GEP = dyn_cast<GEPOperator>(Src)) + Src = GEP->getPointerOperand(); + if (Val == Src) + break; + Val = Src; + } + if (isa<AllocaInst>(Val)) { + NoTail( + "WebAssembly does not support tail calling with stack arguments"); + break; + } + } + } + } + + if (Info.LoweredTailCall) { + MF.getFrameInfo().setHasTailCall(); + } + + MachineInstrBuilder CallInst; + + bool IsIndirect = false; + Register IndirectIdx; + + if (Info.Callee.isReg()) { + IsIndirect = true; + CallInst = MIRBuilder.buildInstr(Info.LoweredTailCall + ? WebAssembly::RET_CALL_INDIRECT + : WebAssembly::CALL_INDIRECT); + } else { + CallInst = MIRBuilder.buildInstr( + Info.LoweredTailCall ?
WebAssembly::RET_CALL : WebAssembly::CALL); + } + + if (!Info.LoweredTailCall) { + if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) { + SmallVector<EVT> SplitEVTs; + ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs); + assert(Info.OrigRet.Regs.size() == SplitEVTs.size() && + "For each split Type there should be exactly one VReg."); + + SmallVector<ArgInfo> SplitReturns; + + unsigned i = 0; + for (auto SplitEVT : SplitEVTs) { + Register CurVReg = Info.OrigRet.Regs[i]; + ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; + if (Info.CB) { + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); + } else { + // We don't have a call base, so chances are we're looking at a + // libcall (external symbol). + + // TODO: figure out how to get ALL the correct attributes + auto &Flags = CurArgInfo.Flags[0]; + PointerType *PtrTy = + dyn_cast<PointerType>(CurArgInfo.Ty->getScalarType()); + if (PtrTy) { + Flags.setPointer(); + Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace()); + } + Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty); + Flags.setMemAlign(MemAlign); + Flags.setOrigAlign(MemAlign); + } + splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv); + ++i; + } + + for (auto &Ret : SplitReturns) { + EVT OrigVT = TLI.getValueType(DL, Ret.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Ret.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT); + + // If we need to split the type over multiple regs, check it's a + // scenario we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Ret.Flags[0]; + Ret.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Ret.Flags.push_back(Flags); + } + + Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Ret.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg.
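+ // Illustrative example (hypothetical callee): an i128 call result arrives +// as two i64 defs on the call instruction; buildCopyFromRegs below merges +// them back into the original 128-bit vreg.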
+ for (unsigned Part = 0; Part < NumParts; ++Part) { + Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT, + Ret.Flags[0], Ret.Ty->isFloatingPointTy()); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + auto NewRetReg = Ret.Regs[Part]; + if (!RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI)) { + NewRetReg = MRI.createGenericVirtualRegister(NewLLT); + assert(RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI) && + "Couldn't constrain brand-new register?"); + MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]); + } + CallInst.addDef(Ret.Regs[Part]); + } + } + } + + if (!Info.CanLowerReturn) { + insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs, + Info.DemoteRegister, Info.DemoteStackIndex); + } + } + auto SavedInsertPt = MIRBuilder.getInsertPt(); + MIRBuilder.setInstr(*CallInst); + + if (Info.Callee.isReg()) { + LLT CalleeType = MRI.getType(Info.Callee.getReg()); + assert(CalleeType.isPointer() && + "Trying to lower a call with a Callee other than a pointer???"); + + // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp + CallInst.addImm(0); + + MCSymbolWasm *Table; + if (CalleeType.getAddressSpace() == + WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) { + Table = WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), + &Subtarget); + IndirectIdx = Info.Callee.getReg(); + + auto PtrSize = CalleeType.getSizeInBits(); + auto PtrIntLLT = LLT::scalar(PtrSize); + + IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0); + } else if (CalleeType.getAddressSpace() == + WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) { + Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(), + &Subtarget); + + Type *PtrTy = PointerType::getUnqual(Ctx); + LLT PtrLLT = getLLTForType(*PtrTy, DL); + auto PtrIntLLT = LLT::scalar(PtrLLT.getSizeInBits()); + + IndirectIdx = MIRBuilder.buildConstant(PtrIntLLT, 0).getReg(0); + + auto TableSetInstr = + MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF); + TableSetInstr.addSym(Table); + TableSetInstr.addUse(IndirectIdx); + TableSetInstr.addUse(Info.Callee.getReg()); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr, + TableSetInstr->getDesc(), + TableSetInstr->getOperand(1), 1); + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr, + TableSetInstr->getDesc(), + TableSetInstr->getOperand(2), 2); + + } else { + fail(MIRBuilder, "Invalid address space for indirect call"); + return false; + } + + if (Subtarget.hasCallIndirectOverlong()) { + CallInst.addSym(Table); + } else { + // For the MVP there is at most one table whose number is 0, but we can't + // write a table symbol or issue relocations. Instead we just ensure the + // table is live and write a zero. 
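+ // Illustrative encoding note (assumption based on the comment above): with +// call-indirect-overlong the operand is the table symbol itself; in the MVP +// fallback it is the literal immediate 0 naming the only table.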
+ Table->setNoStrip(); + CallInst.addImm(0); + } + } else { + if (Info.Callee.isGlobal()) { + CallInst.addGlobalAddress(Info.Callee.getGlobal()); + } else if (Info.Callee.isSymbol()) { + CallInst.addExternalSymbol(Info.Callee.getSymbolName()); + } else { + llvm_unreachable("Trying to lower call with a callee other than reg, " + "global, or a symbol."); + } + } + + SmallVector SplitArgs; + + bool HasSwiftErrorArg = false; + bool HasSwiftSelfArg = false; + + for (const auto &Arg : Info.OrigArgs) { + HasSwiftSelfArg |= Arg.Flags[0].isSwiftSelf(); + HasSwiftErrorArg |= Arg.Flags[0].isSwiftError(); + if (Arg.Flags[0].isNest()) { + fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments"); + return false; + } + if (Arg.Flags[0].isInAlloca()) { + fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments"); + return false; + } + if (Arg.Flags[0].isInConsecutiveRegs()) { + fail(MIRBuilder, "WebAssembly hasn't implemented cons regs arguments"); + return false; + } + if (Arg.Flags[0].isInConsecutiveRegsLast()) { + fail(MIRBuilder, + "WebAssembly hasn't implemented cons regs last arguments"); + return false; + } + + if (Arg.Flags[0].isByVal() && Arg.Flags[0].getByValSize() != 0) { + MachineFrameInfo &MFI = MF.getFrameInfo(); + + unsigned MemSize = Arg.Flags[0].getByValSize(); + Align MemAlign = Arg.Flags[0].getNonZeroByValAlign(); + int FI = MFI.CreateStackObject(Arg.Flags[0].getByValSize(), MemAlign, + /*isSS=*/false); + + auto StackAddrSpace = DL.getAllocaAddrSpace(); + auto PtrLLT = + LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(StackAddrSpace)); + + Register StackObjPtrVreg = + MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + MRI.setRegClass(StackObjPtrVreg, TLI.getRepRegClassFor(TLI.getPointerTy( + DL, StackAddrSpace))); + + MIRBuilder.buildFrameIndex(StackObjPtrVreg, FI); + + MachinePointerInfo DstPtrInfo = MachinePointerInfo::getFixedStack(MF, FI); + + MachinePointerInfo SrcPtrInfo(Arg.OrigValue); + if (!Arg.OrigValue) { + // We still need to accurately track the stack address space if we + // don't know the underlying value. + SrcPtrInfo = MachinePointerInfo::getUnknownStack(MF); + } + + Align DstAlign = + std::max(MemAlign, inferAlignFromPtrInfo(MF, DstPtrInfo)); + + Align SrcAlign = + std::max(MemAlign, inferAlignFromPtrInfo(MF, SrcPtrInfo)); + + MachineMemOperand *SrcMMO = MF.getMachineMemOperand( + SrcPtrInfo, + MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable, + MemSize, SrcAlign); + + MachineMemOperand *DstMMO = MF.getMachineMemOperand( + DstPtrInfo, + MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable, + MemSize, DstAlign); + + const LLT SizeTy = LLT::scalar(PtrLLT.getSizeInBits()); + + auto SizeConst = MIRBuilder.buildConstant(SizeTy, MemSize); + MIRBuilder.buildMemCpy(StackObjPtrVreg, Arg.Regs[0], SizeConst, *DstMMO, + *SrcMMO); + } + + splitToValueTypes(Arg, SplitArgs, DL, CallConv); + } + + unsigned NumFixedArgs = 0; + + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. 
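+ // Note on the byval handling above (illustrative, hypothetical type): a +// 16-byte struct passed byval gets a fresh stack object here in the caller, +// filled from the incoming pointer with G_MEMCPY, so the callee receives its +// own copy.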
+ unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg. + for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + + buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, + Arg.Ty->isFloatingPointTy() + ? TargetOpcode::G_FPEXT + : extendOpFromFlags(Arg.Flags[0])); + } + + if (!Arg.Flags[0].isVarArg()) { + for (unsigned Part = 0; Part < NumParts; ++Part) { + auto NewArgReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part], + *TLI.getRegClassFor(NewVT)); + if (Arg.Regs[Part] != NewArgReg) + MIRBuilder.buildCopy(NewArgReg, Arg.Regs[Part]); + CallInst.addUse(Arg.Regs[Part]); + } + ++NumFixedArgs; + } + } + + if (CallConv == CallingConv::Swift) { + Type *PtrTy = PointerType::getUnqual(Ctx); + LLT PtrLLT = getLLTForType(*PtrTy, DL); + auto &PtrRegClass = *TLI.getRegClassFor(TLI.getSimpleValueType(DL, PtrTy)); + + if (!HasSwiftSelfArg) { + auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0); + MRI.setRegClass(NewUndefReg, &PtrRegClass); + CallInst.addUse(NewUndefReg); + } + if (!HasSwiftErrorArg) { + auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0); + MRI.setRegClass(NewUndefReg, &PtrRegClass); + CallInst.addUse(NewUndefReg); + } + } + + // Analyze operands of the call, assigning locations to each operand. + SmallVector ArgLocs; + CCState CCInfo(CallConv, Info.IsVarArg, MF, ArgLocs, Ctx); + + if (Info.IsVarArg) { + // Outgoing non-fixed arguments are placed in a buffer. First + // compute their offsets and the total amount of buffer space needed. 
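+ // Illustrative layout (hypothetical call): passing (i32, f64) through the +// vararg buffer places the i32 at offset 0 and the f64 at offset 8 after +// alignment, for NumBytes == 16.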
+ for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + Type *Ty = EVT(PartVT).getTypeForEVT(Ctx); + + for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) { + Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(), + DL.getABITypeAlign(Ty)); + unsigned Offset = + CCInfo.AllocateStack(DL.getTypeAllocSize(Ty), Alignment); + CCInfo.addLoc(CCValAssign::getMem(ArgLocs.size(), PartVT, Offset, + PartVT, CCValAssign::Full)); + } + } + } + + unsigned NumBytes = CCInfo.getAlignedCallFrameSize(); + + auto StackAddrSpace = DL.getAllocaAddrSpace(); + auto PtrLLT = LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(0)); + auto SizeLLT = LLT::scalar(DL.getPointerSizeInBits(StackAddrSpace)); + auto *PtrRegClass = TLI.getRegClassFor(TLI.getPointerTy(DL, StackAddrSpace)); + + if (Info.IsVarArg && NumBytes) { + Register VarArgStackPtr = + MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + MRI.setRegClass(VarArgStackPtr, PtrRegClass); + + MaybeAlign StackAlign = DL.getStackAlignment(); + assert(StackAlign && "data layout string is missing stack alignment"); + int FI = MF.getFrameInfo().CreateStackObject(NumBytes, *StackAlign, + /*isSS=*/false); + + MIRBuilder.buildFrameIndex(VarArgStackPtr, FI); + + unsigned ValNo = 0; + for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + Type *Ty = EVT(PartVT).getTypeForEVT(Ctx); + + for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) { + Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(), + DL.getABITypeAlign(Ty)); + + unsigned Offset = ArgLocs[ValNo++].getLocMemOffset(); + + Register DstPtr = + MIRBuilder + .buildPtrAdd( + PtrLLT, VarArgStackPtr, + MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0)) + .getReg(0); + + MachineMemOperand *DstMMO = MF.getMachineMemOperand( + MachinePointerInfo::getFixedStack(MF, FI, Offset), + MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable, + PartVT.getStoreSize(), Alignment); + + MIRBuilder.buildStore(Arg.Regs[Part], DstPtr, *DstMMO); + } + } + + CallInst.addUse(VarArgStackPtr); + } else if (Info.IsVarArg) { + auto NewArgReg = MIRBuilder.buildConstant(PtrLLT, 0).getReg(0); + MRI.setRegClass(NewArgReg, PtrRegClass); + CallInst.addUse(NewArgReg); + } + + if (IsIndirect) { + auto NewArgReg = + constrainRegToClass(MRI, TII, RBI, IndirectIdx, *PtrRegClass); + if (IndirectIdx != NewArgReg) + MIRBuilder.buildCopy(NewArgReg, IndirectIdx); + CallInst.addUse(IndirectIdx); + } + + MIRBuilder.setInsertPt(MIRBuilder.getMBB(), SavedInsertPt); + + return true; +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h new file mode 100644 index 0000000000000..d22f7cbd17eb3 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h @@ -0,0 +1,43 @@ +//===-- WebAssemblyCallLowering.h - Call lowering for GlobalISel -*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file describes how to lower LLVM calls to machine code calls. 
+/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H + +#include "WebAssemblyISelLowering.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/IR/CallingConv.h" + +namespace llvm { + +class WebAssemblyTargetLowering; + +class WebAssemblyCallLowering : public CallLowering { +public: + WebAssemblyCallLowering(const WebAssemblyTargetLowering &TLI); + + bool canLowerReturn(MachineFunction &MF, CallingConv::ID CallConv, + SmallVectorImpl &Outs, + bool IsVarArg) const override; + bool lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, + ArrayRef VRegs, FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const override; + bool lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, + ArrayRef> VRegs, + FunctionLoweringInfo &FLI) const override; + bool lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const override; +}; +} // namespace llvm + +#endif diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp new file mode 100644 index 0000000000000..0ef5f357718ac --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp @@ -0,0 +1,603 @@ +//===- WebAssemblyInstructionSelector.cpp ------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the targeting of the InstructionSelector class for +/// WebAssembly. +/// \todo This should be generated by TableGen. 
+//===----------------------------------------------------------------------===// + +#include "GISel/WebAssemblyRegisterBankInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "Utils/WasmAddressSpaces.h" +#include "Utils/WebAssemblyTypeUtilities.h" +#include "WebAssemblyRegisterInfo.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyTargetMachine.h" +#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/GlobalISel/Utils.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/RegisterBank.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "wasm-isel" + +using namespace llvm; + +namespace { + +#define GET_GLOBALISEL_PREDICATE_BITSET +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATE_BITSET + +class WebAssemblyInstructionSelector : public InstructionSelector { +public: + WebAssemblyInstructionSelector(const WebAssemblyTargetMachine &TM, + const WebAssemblySubtarget &STI, + const WebAssemblyRegisterBankInfo &RBI); + + bool select(MachineInstr &I) override; + + InstructionSelector::ComplexRendererFns + selectAddrOperands32(MachineOperand &Root) const; + InstructionSelector::ComplexRendererFns + selectAddrOperands64(MachineOperand &Root) const; + + static const char *getName() { return DEBUG_TYPE; } + +private: + bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const; + bool selectCopy(MachineInstr &I, MachineRegisterInfo &MRI) const; + + InstructionSelector::ComplexRendererFns + selectAddrOperands(LLT AddrType, unsigned int ConstOpc, + MachineOperand &Root) const; + + const WebAssemblyTargetMachine &TM; + const WebAssemblySubtarget &STI; + const WebAssemblyInstrInfo &TII; + const WebAssemblyRegisterInfo &TRI; + const WebAssemblyRegisterBankInfo &RBI; + +#define GET_GLOBALISEL_PREDICATES_DECL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_DECL + +#define GET_GLOBALISEL_TEMPORARIES_DECL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_DECL +}; + +} // end anonymous namespace + +#define GET_GLOBALISEL_IMPL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_IMPL + +WebAssemblyInstructionSelector::WebAssemblyInstructionSelector( + const WebAssemblyTargetMachine &TM, const WebAssemblySubtarget &STI, + const WebAssemblyRegisterBankInfo &RBI) + : TM(TM), STI(STI), TII(*STI.getInstrInfo()), TRI(*STI.getRegisterInfo()), + RBI(RBI), + +#define GET_GLOBALISEL_PREDICATES_INIT +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_INIT +#define GET_GLOBALISEL_TEMPORARIES_INIT +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_INIT +{ +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands(LLT AddrType, + unsigned int ConstOpc, + MachineOperand &Root) const { + + if (!Root.isReg()) + return std::nullopt; + + MachineRegisterInfo &MRI = + Root.getParent()->getParent()->getParent()->getRegInfo(); + MachineInstr &RootDef = *MRI.getVRegDef(Root.getReg()); + + if (RootDef.getOpcode() == TargetOpcode::G_PTR_ADD) { + // RootDef will always be 
G_PTR_ADD + MachineOperand &LHS = RootDef.getOperand(1); + + MachineOperand &RHS = RootDef.getOperand(2); + MachineInstr &LHSDef = *MRI.getVRegDef(LHS.getReg()); + MachineInstr &RHSDef = + *MRI.getVRegDef(RHS.getReg()); // Will always be G_CONSTANT + + // WebAssembly constant offsets are performed as unsigned with infinite + // precision, so we need to check for NoUnsignedWrap so that we don't fold + // and offset for an add that needs wrapping. + if (RootDef.getFlag(MachineInstr::MIFlag::NoUWrap)) { + for (size_t i = 0; i < 2; ++i) { + // MachineOperand &Op = i == 0 ? LHS : RHS; + MachineInstr &OpDef = i == 0 ? LHSDef : RHSDef; + MachineOperand &OtherOp = i == 0 ? RHS : LHS; + // MachineInstr &OtherOpDef = i == 0 ? RHSDef : LHSDef; + + if (OpDef.getOpcode() == TargetOpcode::G_CONSTANT) { + auto Offset = OpDef.getOperand(1).getCImm()->getZExtValue(); + auto Addr = OtherOp; + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.add(Addr); }, + }}; + } + + if (!TM.isPositionIndependent()) { + if (OpDef.getOpcode() == TargetOpcode::G_GLOBAL_VALUE) { + auto Offset = OpDef.getOperand(1).getGlobal(); + auto Addr = OtherOp; + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addGlobalAddress(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.add(Addr); }, + }}; + } + } + } + } + } + + if (RootDef.getOpcode() == TargetOpcode::G_CONSTANT) { + auto Offset = RootDef.getOperand(1).getCImm()->getZExtValue(); + auto Addr = MRI.createGenericVirtualRegister(AddrType); + + MachineIRBuilder B(RootDef); + + auto MIB = B.buildInstr(ConstOpc).addDef(Addr).addImm(0); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.addReg(Addr); }, + }}; + } + + if (!TM.isPositionIndependent() && + RootDef.getOpcode() == TargetOpcode::G_GLOBAL_VALUE) { + auto *Offset = RootDef.getOperand(1).getGlobal(); + auto Addr = MRI.createGenericVirtualRegister(AddrType); + + MachineIRBuilder B(RootDef); + + auto MIB = B.buildInstr(ConstOpc).addDef(Addr).addImm(0); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addGlobalAddress(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.addReg(Addr); }, + }}; + } + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(0); }, + [=](MachineInstrBuilder &MIB) { MIB.add(Root); }, + }}; +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands32( + MachineOperand &Root) const { + return selectAddrOperands(LLT::scalar(32), WebAssembly::CONST_I32, Root); +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands64( + MachineOperand &Root) const { + return selectAddrOperands(LLT::scalar(64), WebAssembly::CONST_I64, Root); +} + +bool WebAssemblyInstructionSelector::selectCopy( + MachineInstr &I, MachineRegisterInfo &MRI) const { + Register DstReg = I.getOperand(0).getReg(); + Register SrcReg = I.getOperand(1).getReg(); + + const TargetRegisterClass *DstRC; + if (DstReg.isPhysical()) { + switch (DstReg.id()) { + case WebAssembly::SP32: + DstRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::SP64: + DstRC = &WebAssembly::I64RegClass; + break; + default: + llvm_unreachable("Copy to physical register other than SP32 or SP64?"); + } + } else { + DstRC = 
MRI.getRegClassOrNull(DstReg); + } + + if (!DstRC) { + const RegisterBank *DstBank = MRI.getRegBankOrNull(DstReg); + if (!DstBank) { + llvm_unreachable("Selecting copy with dst reg with no bank?"); + } + + switch (DstBank->getID()) { + case WebAssembly::I32RegBankID: + DstRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::I64RegBankID: + DstRC = &WebAssembly::I64RegClass; + break; + case WebAssembly::F32RegBankID: + DstRC = &WebAssembly::F32RegClass; + break; + case WebAssembly::F64RegBankID: + DstRC = &WebAssembly::F64RegClass; + break; + default: + llvm_unreachable("Unknown reg bank to reg class mapping?"); + } + if (!constrainOperandRegClass(*MF, TRI, MRI, TII, RBI, I, *DstRC, + I.getOperand(0))) { + LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) + << " operand\n"); + return false; + } + } + + const TargetRegisterClass *SrcRC; + if (SrcReg.isPhysical()) { + switch (SrcReg.id()) { + case WebAssembly::SP32: + SrcRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::SP64: + SrcRC = &WebAssembly::I64RegClass; + break; + default: + llvm_unreachable("Copy to physical register other than SP32 or SP64?"); + } + } else { + SrcRC = MRI.getRegClassOrNull(SrcReg); + } + if (!SrcRC) { + const RegisterBank *SrcBank = MRI.getRegBankOrNull(SrcReg); + if (!SrcBank) { + llvm_unreachable("Selecting copy with src reg with no bank?"); + } + + switch (SrcBank->getID()) { + case WebAssembly::I32RegBankID: + SrcRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::I64RegBankID: + SrcRC = &WebAssembly::I64RegClass; + break; + case WebAssembly::F32RegBankID: + SrcRC = &WebAssembly::F32RegClass; + break; + case WebAssembly::F64RegBankID: + SrcRC = &WebAssembly::F64RegClass; + break; + default: + llvm_unreachable("Unknown reg bank to reg class mapping?"); + } + if (!constrainOperandRegClass(*MF, TRI, MRI, TII, RBI, I, *SrcRC, + I.getOperand(1))) { + LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) + << " operand\n"); + return false; + } + } + + assert(TRI.getRegSizeInBits(*DstRC) == TRI.getRegSizeInBits(*SrcRC) && + "Copy between mismatching register sizes?"); + + if (DstRC != SrcRC) { + if (DstRC == &WebAssembly::I32RegClass && + SrcRC == &WebAssembly::F32RegClass) { + I.setDesc(TII.get(WebAssembly::I32_REINTERPRET_F32)); + return true; + } + + if (DstRC == &WebAssembly::F32RegClass && + SrcRC == &WebAssembly::I32RegClass) { + I.setDesc(TII.get(WebAssembly::F32_REINTERPRET_I32)); + return true; + } + + if (DstRC == &WebAssembly::I64RegClass && + SrcRC == &WebAssembly::F64RegClass) { + I.setDesc(TII.get(WebAssembly::I64_REINTERPRET_F64)); + return true; + } + + if (DstRC == &WebAssembly::F64RegClass && + SrcRC == &WebAssembly::I64RegClass) { + I.setDesc(TII.get(WebAssembly::F64_REINTERPRET_I64)); + return true; + } + + llvm_unreachable("Found bitcast/copy edge case."); + } + + return true; +} + +static const TargetRegisterClass * +getRegClassForTypeOnBank(const RegisterBank &RB) { + switch (RB.getID()) { + case WebAssembly::I32RegBankID: + return &WebAssembly::I32RegClass; + case WebAssembly::I64RegBankID: + return &WebAssembly::I64RegClass; + case WebAssembly::F32RegBankID: + return &WebAssembly::F32RegClass; + case WebAssembly::F64RegBankID: + return &WebAssembly::F64RegClass; + case WebAssembly::EXNREFRegBankID: + return &WebAssembly::EXNREFRegClass; + case WebAssembly::EXTERNREFRegBankID: + return &WebAssembly::EXTERNREFRegClass; + case WebAssembly::FUNCREFRegBankID: + return &WebAssembly::FUNCREFRegClass; + // case 
WebAssembly::V128RegBankID: + // return &WebAssembly::V128RegClass; + } + + return nullptr; +} + +bool WebAssemblyInstructionSelector::select(MachineInstr &I) { + MachineBasicBlock &MBB = *I.getParent(); + MachineFunction &MF = *MBB.getParent(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetLowering &TLI = *STI.getTargetLowering(); + + if (!I.isPreISelOpcode() || I.getOpcode() == TargetOpcode::G_PHI) { + if (I.getOpcode() == TargetOpcode::PHI || + I.getOpcode() == TargetOpcode::G_PHI) { + const Register DefReg = I.getOperand(0).getReg(); + const LLT DefTy = MRI.getType(DefReg); + + const RegClassOrRegBank &RegClassOrBank = + MRI.getRegClassOrRegBank(DefReg); + + const TargetRegisterClass *DefRC = + dyn_cast<const TargetRegisterClass *>(RegClassOrBank); + + if (!DefRC) { + if (!DefTy.isValid()) { + LLVM_DEBUG(dbgs() << "PHI operand has no type, not a gvreg?\n"); + return false; + } + const RegisterBank &RB = *cast<const RegisterBank *>(RegClassOrBank); + DefRC = getRegClassForTypeOnBank(RB); + if (!DefRC) { + LLVM_DEBUG(dbgs() << "PHI operand has unexpected size/bank\n"); + return false; + } + } + + I.setDesc(TII.get(TargetOpcode::PHI)); + + return RBI.constrainGenericRegister(DefReg, *DefRC, MRI) != nullptr; + } + } + + if (!isPreISelGenericOpcode(I.getOpcode())) { + if (I.isCopy()) + return selectCopy(I, MRI); + + return true; + } + + if (selectImpl(I, *CoverageInfo)) + return true; + + using namespace TargetOpcode; + + auto PointerWidth = MF.getDataLayout().getPointerSizeInBits(); + auto PtrIsI64 = PointerWidth == 64; + + switch (I.getOpcode()) { + case G_IMPLICIT_DEF: { + const Register DefReg = I.getOperand(0).getReg(); + const LLT DefTy = MRI.getType(DefReg); + + const RegClassOrRegBank &RegClassOrBank = MRI.getRegClassOrRegBank(DefReg); + + const TargetRegisterClass *DefRC = + dyn_cast<const TargetRegisterClass *>(RegClassOrBank); + + if (!DefRC) { + if (!DefTy.isValid()) { + LLVM_DEBUG( + dbgs() << "IMPLICIT_DEF operand has no type, not a gvreg?\n"); + return false; + } + const RegisterBank &RB = *cast<const RegisterBank *>(RegClassOrBank); + DefRC = getRegClassForTypeOnBank(RB); + if (!DefRC) { + LLVM_DEBUG(dbgs() << "IMPLICIT_DEF operand has unexpected size/bank\n"); + return false; + } + } + + I.setDesc(TII.get(TargetOpcode::IMPLICIT_DEF)); + + return RBI.constrainGenericRegister(DefReg, *DefRC, MRI) != nullptr; + } + case G_BRJT: { + auto JT = I.getOperand(1); + auto Index = I.getOperand(2); + + assert(JT.getTargetFlags() == 0 && "WebAssembly doesn't set target flags"); + + MachineIRBuilder B(I); + + MachineJumpTableInfo *MJTI = MF.getJumpTableInfo(); + const auto &MBBs = MJTI->getJumpTables()[JT.getIndex()].MBBs; + + auto MIB = B.buildInstr(PtrIsI64 ? WebAssembly::BR_TABLE_I64 + : WebAssembly::BR_TABLE_I32) + .add(Index); + + for (auto *MBB : MBBs) + MIB.addMBB(MBB); + + // Add the first MBB as a dummy default target for now. This will be + // replaced with the proper default target (and the preceding range check + // eliminated) if possible by WebAssemblyFixBrTableDefaults. + MIB.addMBB(*MBBs.begin()); + + // Constrain unconditionally; wrapping this call in assert() would skip it + // entirely in NDEBUG builds. + if (!constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI)) + return false; + + I.eraseFromParent(); + return true; + } + case G_PTR_ADD: { + assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && + "G_PTR_ADD selection fell-through with non-pointer?"); + + I.setDesc(TII.get(PtrIsI64 ? WebAssembly::ADD_I64 : WebAssembly::ADD_I32)); + return constrainSelectedInstRegOperands(I, TII, TRI, RBI); + } + case G_PTRTOINT: { + assert(MRI.getType(I.getOperand(1).getReg()).isPointer() && + "G_PTRTOINT selection fell-through with non-pointer?"); + + I.setDesc(TII.get(PtrIsI64 ? WebAssembly::COPY_I64 : WebAssembly::COPY_I32)); + return constrainSelectedInstRegOperands(I, TII, TRI, RBI); + } + case G_INTTOPTR: { + assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && + "G_INTTOPTR selection fell-through with non-pointer?"); + + I.setDesc(TII.get(PtrIsI64 ? WebAssembly::COPY_I64 : WebAssembly::COPY_I32)); + return constrainSelectedInstRegOperands(I, TII, TRI, RBI); + } + case G_FRAME_INDEX: { + I.setDesc( + TII.get(PtrIsI64 ? WebAssembly::COPY_I64 : WebAssembly::COPY_I32)); + return constrainSelectedInstRegOperands(I, TII, TRI, RBI); + } + case G_GLOBAL_VALUE: { + assert(I.getOperand(1).getTargetFlags() == 0 && + "Unexpected target flags on generic G_GLOBAL_VALUE instruction"); + assert(WebAssembly::isValidAddressSpace( + MRI.getType(I.getOperand(0).getReg()).getAddressSpace()) && + "Invalid address space for WebAssembly target"); + + unsigned OperandFlags = 0; + const llvm::GlobalValue *GV = I.getOperand(1).getGlobal(); + // Since WebAssembly tables cannot yet be shared across modules, we don't + // need special treatment for tables in PIC mode. + if (TLI.isPositionIndependent() && + !WebAssembly::isWebAssemblyTableType(GV->getValueType())) { + if (TM.shouldAssumeDSOLocal(GV)) { + const char *BaseName; + if (GV->getValueType()->isFunctionTy()) { + BaseName = MF.createExternalSymbolName("__table_base"); + OperandFlags = WebAssemblyII::MO_TABLE_BASE_REL; + } else { + BaseName = MF.createExternalSymbolName("__memory_base"); + OperandFlags = WebAssemblyII::MO_MEMORY_BASE_REL; + } + MachineIRBuilder B(I); + + auto MemBase = + MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(MemBase, PtrIsI64 ? &WebAssembly::I64RegClass + : &WebAssembly::I32RegClass); + auto Offset = + MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(Offset, PtrIsI64 ? &WebAssembly::I64RegClass + : &WebAssembly::I32RegClass); + + B.buildInstr(PtrIsI64 ? WebAssembly::GLOBAL_GET_I64 + : WebAssembly::GLOBAL_GET_I32) + .addDef(MemBase) + .addExternalSymbol(BaseName); + + B.buildInstr(PtrIsI64 ? WebAssembly::CONST_I64 : WebAssembly::CONST_I32) + .addDef(Offset) + .addGlobalAddress(GV, I.getOperand(1).getOffset(), OperandFlags); + + auto MIB = + B.buildInstr(PtrIsI64 ? WebAssembly::ADD_I64 : WebAssembly::ADD_I32) + .addDef(I.getOperand(0).getReg()) + .addReg(MemBase) + .addReg(Offset); + if (!constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI)) + return false; + + I.eraseFromParent(); + return true; + } + OperandFlags = WebAssemblyII::MO_GOT; + } + + auto NewOpc = PtrIsI64 ? WebAssembly::CONST_I64 : WebAssembly::CONST_I32; + + if (OperandFlags & WebAssemblyII::MO_GOT) { + NewOpc = PtrIsI64 ? WebAssembly::GLOBAL_GET_I64 + : WebAssembly::GLOBAL_GET_I32; + } + + I.setDesc(TII.get(NewOpc)); + I.getOperand(1).setTargetFlags(OperandFlags); + return constrainSelectedInstRegOperands(I, TII, TRI, RBI); + } + } + + return false; +} + +namespace llvm { +InstructionSelector * +createWebAssemblyInstructionSelector(const WebAssemblyTargetMachine &TM, + const WebAssemblySubtarget &Subtarget, + const WebAssemblyRegisterBankInfo &RBI) { + return new WebAssemblyInstructionSelector(TM, Subtarget, RBI); +} +} // namespace llvm diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp new file mode 100644 index 0000000000000..3e9d5957a22bc --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -0,0 +1,435 @@ +//===- WebAssemblyLegalizerInfo.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the targeting of the MachineLegalizer class for +/// WebAssembly. +//===----------------------------------------------------------------------===// + +#include "WebAssemblyLegalizerInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "WebAssemblySubtarget.h" +#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/DerivedTypes.h" + +#define DEBUG_TYPE "wasm-legalinfo" + +using namespace llvm; +using namespace LegalizeActions; + +WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( + const WebAssemblySubtarget &ST) { + using namespace TargetOpcode; + const LLT s8 = LLT::scalar(8); + const LLT s16 = LLT::scalar(16); + const LLT s32 = LLT::scalar(32); + const LLT s64 = LLT::scalar(64); + + const LLT p0 = LLT::pointer(0, ST.hasAddr64() ? 64 : 32); + const LLT p0s = LLT::scalar(ST.hasAddr64() ? 
64 : 32); + + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor({p0}); + + getActionDefinitionsBuilder(G_PHI) + .legalFor({p0, s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + getActionDefinitionsBuilder(G_BR).alwaysLegal(); + getActionDefinitionsBuilder(G_BRCOND).legalFor({s32}).clampScalar(0, s32, + s32); + getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, p0s}}); + + getActionDefinitionsBuilder(G_SELECT) + .legalFor({{s32, s32}, {s64, s32}, {p0, s32}}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s32); + + getActionDefinitionsBuilder(G_JUMP_TABLE).legalFor({p0}); + + getActionDefinitionsBuilder(G_ICMP) + .legalFor({{s32, s32}, {s32, s64}, {s32, p0}}) + .widenScalarToNextPow2(1) + .clampScalar(1, s32, s64) + .clampScalar(0, s32, s32); + + getActionDefinitionsBuilder(G_FCMP) + .customFor({{s32, s32}, {s32, s64}}) + .clampScalar(0, s32, s32) + .libcall(); + + getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + + getActionDefinitionsBuilder(G_CONSTANT) + .legalFor({s32, s64, p0}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder(G_FCONSTANT) + .legalFor({s32, s64}) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder(G_IMPLICIT_DEF) + .legalFor({s32, s64, p0}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder( + {G_ADD, G_SUB, G_MUL, G_UDIV, G_SDIV, G_UREM, G_SREM}) + .legalFor({s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP}) + .legalFor({{s32, s32}, {s64, s64}}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64) + .minScalarSameAs(1, 0) + .maxScalarSameAs(1, 0); + + getActionDefinitionsBuilder({G_FSHL, G_FSHR}) + .legalFor({{s32, s32}, {s64, s64}}) + .lower(); + + getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower(); + + getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) + .legalFor({s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_UMIN, G_UMAX, G_SMIN, G_SMAX}).lower(); + + getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FDIV, G_FMUL, G_FNEG, G_FABS, + G_FCEIL, G_FFLOOR, G_FSQRT, G_INTRINSIC_TRUNC, + G_FNEARBYINT, G_FRINT, G_INTRINSIC_ROUNDEVEN, + G_FMINIMUM, G_FMAXIMUM}) + .legalFor({s32, s64}) + .minScalar(0, s32); + + // TODO: _IEEE not lowering correctly? 
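+ // (Note: wasm f32.min/f64.min propagate NaNs, matching the NaN-propagating + // G_FMINIMUM/G_FMAXIMUM made legal above; the minnum/maxnum variants return + // the non-NaN operand instead, so they are lowered generically below.)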
+ getActionDefinitionsBuilder( + {G_FMINNUM, G_FMAXNUM, G_FMINNUM_IEEE, G_FMAXNUM_IEEE}) + .lowerFor({s32, s64}) + .minScalar(0, s32); + + getActionDefinitionsBuilder({G_FMA, G_FREM}) + .libcallFor({s32, s64}) + .minScalar(0, s32); + + getActionDefinitionsBuilder(G_LROUND).libcallForCartesianProduct({s32}, + {s32, s64}); + + getActionDefinitionsBuilder(G_LLROUND).libcallForCartesianProduct({s64}, + {s32, s64}); + + getActionDefinitionsBuilder(G_FCOPYSIGN) + .legalFor({s32, s64}) + .minScalar(0, s32) + .minScalarSameAs(1, 0) + .maxScalarSameAs(1, 0); + + getActionDefinitionsBuilder({G_FPTOUI, G_FPTOUI_SAT, G_FPTOSI, G_FPTOSI_SAT}) + .legalForCartesianProduct({s32, s64}, {s32, s64}) + .minScalar(1, s32) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_UITOFP, G_SITOFP}) + .legalForCartesianProduct({s32, s64}, {s32, s64}) + .minScalar(1, s32) + .widenScalarToNextPow2(1) + .clampScalar(1, s32, s64); + + getActionDefinitionsBuilder(G_PTRTOINT) + .legalFor({{p0s, p0}}) + .customForCartesianProduct({s32, s64}, {p0}); + getActionDefinitionsBuilder(G_INTTOPTR) + .legalFor({{p0, p0s}}) + .customForCartesianProduct({p0}, {s32, s64}); + getActionDefinitionsBuilder(G_PTR_ADD).legalFor({{p0, p0s}}); + + getActionDefinitionsBuilder(G_LOAD) + .legalForTypesWithMemDesc( + {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .clampScalar(0, s32, s64) + .lowerIfMemSizeNotByteSizePow2(); + + getActionDefinitionsBuilder(G_STORE) + .legalForTypesWithMemDesc( + {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .clampScalar(0, s32, s64) + .lowerIfMemSizeNotByteSizePow2(); + + getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .clampScalar(0, s32, s64) + .lowerIfMemSizeNotByteSizePow2(); + + if (ST.hasBulkMemoryOpt()) { + getActionDefinitionsBuilder(G_BZERO).unsupported(); + + getActionDefinitionsBuilder(G_MEMSET) + .legalForCartesianProduct({p0}, {s32}, {p0s}) + .customForCartesianProduct({p0}, {s8}, {p0s}) + .immIdx(0); + + getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) + .legalForCartesianProduct({p0}, {p0}, {p0s}) + .immIdx(0); + + getActionDefinitionsBuilder(G_MEMCPY_INLINE) + .legalForCartesianProduct({p0}, {p0}, {p0s}); + } else { + getActionDefinitionsBuilder({G_BZERO, G_MEMCPY, G_MEMMOVE, G_MEMSET}) + .libcall(); + } + + // TODO: figure out how to combine G_ANYEXT of G_ASSERT_{S|Z}EXT (or + // appropriate G_AND and G_SEXT_INREG?) to a G_{S|Z}EXT + G_ASSERT_{S|Z}EXT + // for better optimization (since G_ANYEXT will lower to a ZEXT or SEXT + // instruction anyway). 
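+ // A hypothetical MIR sketch of that combine: + // %v:_(s32) = G_ASSERT_ZEXT %x(s32), 8 + // %w:_(s64) = G_ANYEXT %v(s32) + // could become %w:_(s64) = G_ZEXT %v(s32), which then selects to a single + // i64.extend_i32_u.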
+ + getActionDefinitionsBuilder(G_ANYEXT) + .legalFor({{s64, s32}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64); + + getActionDefinitionsBuilder({G_SEXT, G_ZEXT}) + .legalFor({{s64, s32}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64) + .lower(); + + if (ST.hasSignExt()) { + getActionDefinitionsBuilder(G_SEXT_INREG) + .clampScalar(0, s32, s64) + .customFor({s32, s64}); + } else { + getActionDefinitionsBuilder(G_SEXT_INREG).lower(); + } + + getActionDefinitionsBuilder(G_TRUNC) + .legalFor({{s32, s64}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64) + .lower(); + + getActionDefinitionsBuilder(G_FPEXT).legalFor({{s64, s32}}); + + getActionDefinitionsBuilder(G_FPTRUNC).legalFor({{s32, s64}}); + + getActionDefinitionsBuilder(G_VASTART).legalFor({p0}); + getActionDefinitionsBuilder(G_VAARG) + .legalForCartesianProduct({s32, s64}, {p0}) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder(G_DYN_STACKALLOC).lowerFor({{p0, p0s}}); + + getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).lower(); + + getLegacyLegalizerInfo().computeTables(); +} + +bool WebAssemblyLegalizerInfo::legalizeCustom( + LegalizerHelper &Helper, MachineInstr &MI, + LostDebugLocObserver &LocObserver) const { + auto &MRI = *Helper.MIRBuilder.getMRI(); + auto &MIRBuilder = Helper.MIRBuilder; + + switch (MI.getOpcode()) { + case TargetOpcode::G_PTRTOINT: { + auto TmpReg = MRI.createGenericVirtualRegister( + LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); + + MIRBuilder.buildPtrToInt(TmpReg, MI.getOperand(1)); + MIRBuilder.buildAnyExtOrTrunc(MI.getOperand(0), TmpReg); + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_INTTOPTR: { + auto TmpReg = MRI.createGenericVirtualRegister( + LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); + + MIRBuilder.buildAnyExtOrTrunc(TmpReg, MI.getOperand(1)); + MIRBuilder.buildIntToPtr(MI.getOperand(0), TmpReg); + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_FCMP: { + Register LHS = MI.getOperand(2).getReg(); + Register RHS = MI.getOperand(3).getReg(); + CmpInst::Predicate Cond = + static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); + + assert(MRI.getType(LHS) == MRI.getType(RHS) && + "LHS and RHS of G_FCMP have different types"); + + switch (Cond) { + case CmpInst::FCMP_FALSE: + return false; + case CmpInst::FCMP_OEQ: + return true; + case CmpInst::FCMP_OGT: + return true; + case CmpInst::FCMP_OGE: + return true; + case CmpInst::FCMP_OLT: + return true; + case CmpInst::FCMP_OLE: + return true; + case CmpInst::FCMP_ONE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, TmpRegA, LHS, RHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, TmpRegB, LHS, RHS); + MIRBuilder.buildOr(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } + case CmpInst::FCMP_ORD: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OEQ, TmpRegA, LHS, LHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_OEQ, TmpRegB, RHS, RHS); + MIRBuilder.buildAnd(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } 
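+ // FCMP_UNO mirrors FCMP_ORD above: x UNE x holds exactly when x is NaN, + // so ORing the two self-compares covers the case where either operand is + // NaN.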
+ case CmpInst::FCMP_UNO: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_UNE, TmpRegA, LHS, LHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_UNE, TmpRegB, RHS, RHS); + MIRBuilder.buildOr(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } + case CmpInst::FCMP_UEQ: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_ONE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UGT: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OLE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UGE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_ULT: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_ULE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UNE: + return true; + case CmpInst::FCMP_TRUE: + return false; + default: + llvm_unreachable("Unknown FCMP predicate"); + } + + MI.eraseFromParent(); + + return true; + } + case TargetOpcode::G_SEXT_INREG: { + assert(MI.getOperand(2).isImm() && "Expected immediate"); + + // Keep 8/16/32-bit G_SEXT_INREG legal; those map directly to the + // sign-extension instructions. + auto [DstReg, SrcReg] = MI.getFirst2Regs(); + auto DstType = MRI.getType(DstReg); + auto ExtFromWidth = MI.getOperand(2).getImm(); + + if (ExtFromWidth == 8 || ExtFromWidth == 16 || + (DstType.getScalarSizeInBits() == 64 && ExtFromWidth == 32)) { + return true; + } + + // Otherwise expand to a shift-left/arithmetic-shift-right pair. + Register TmpRes = MRI.createGenericVirtualRegister(DstType); + + auto MIBSz = MIRBuilder.buildConstant( + DstType, DstType.getScalarSizeInBits() - ExtFromWidth); + MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz); + MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz); + MI.eraseFromParent(); + + return true; + } + case TargetOpcode::G_MEMSET: { + // Anyext the value being set to 32 bits (only the bottom 8 bits are read + // by the instruction). 
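+ // (Note: the bulk-memory memory.fill instruction takes its fill value as + // an i32 and reads only the low byte, so a plain G_ANYEXT is sufficient.)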
+ Helper.Observer.changingInstr(MI); + auto &Value = MI.getOperand(1); + + Register ExtValueReg = + Helper.MIRBuilder.buildAnyExt(LLT::scalar(32), Value).getReg(0); + Value.setReg(ExtValueReg); + Helper.Observer.changedInstr(MI); + return true; + } + default: + break; + } + return false; +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h new file mode 100644 index 0000000000000..5aca23c9514e1 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h @@ -0,0 +1,31 @@ +//===- WebAssemblyLegalizerInfo.h --------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file declares the targeting of the MachineLegalizer class for +/// WebAssembly. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H + +#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" + +namespace llvm { + +class WebAssemblySubtarget; + +/// This class provides the information for the WebAssembly target legalizer +/// for GlobalISel. +class WebAssemblyLegalizerInfo : public LegalizerInfo { +public: + WebAssemblyLegalizerInfo(const WebAssemblySubtarget &ST); + + bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, + LostDebugLocObserver &LocObserver) const override; +}; +} // namespace llvm +#endif diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp new file mode 100644 index 0000000000000..096cd2125ec22 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -0,0 +1,425 @@ +//===- WebAssemblyRegisterBankInfo.cpp ---------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the targeting of the RegisterBankInfo class for +/// WebAssembly. +//===----------------------------------------------------------------------===// + +#include "WebAssemblyRegisterBankInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyTargetMachine.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/Support/ErrorHandling.h" + +#define GET_TARGET_REGBANK_IMPL + +#include "WebAssemblyGenRegisterBank.inc" + +namespace llvm { +namespace WebAssembly { +enum PartialMappingIdx { + PMI_None = -1, + PMI_I32 = 1, + PMI_I64, + PMI_F32, + PMI_F64, + PMI_Min = PMI_I32, +}; + +enum ValueMappingIdx { + InvalidIdx = 0, + I32Idx = 1, + I64Idx = 4, + F32Idx = 7, + F64Idx = 10 +}; + +const RegisterBankInfo::PartialMapping PartMappings[]{{0, 32, I32RegBank}, + {0, 64, I64RegBank}, + {0, 32, F32RegBank}, + {0, 64, F64RegBank}}; + +const RegisterBankInfo::ValueMapping ValueMappings[] = { + // invalid + {nullptr, 0}, + // up to 3 operands as I32 + {&PartMappings[PMI_I32 - PMI_Min], 1}, + {&PartMappings[PMI_I32 - PMI_Min], 1}, + {&PartMappings[PMI_I32 - PMI_Min], 1}, + // up to 3 operands as I64 + {&PartMappings[PMI_I64 - PMI_Min], 1}, + {&PartMappings[PMI_I64 - PMI_Min], 1}, + {&PartMappings[PMI_I64 - PMI_Min], 1}, + // up to 3 operands as F32 + {&PartMappings[PMI_F32 - PMI_Min], 1}, + {&PartMappings[PMI_F32 - PMI_Min], 1}, + {&PartMappings[PMI_F32 - PMI_Min], 1}, + // up to 3 operands as F64 + {&PartMappings[PMI_F64 - PMI_Min], 1}, + {&PartMappings[PMI_F64 - PMI_Min], 1}, + {&PartMappings[PMI_F64 - PMI_Min], 1}}; + +} // namespace WebAssembly +} // 
namespace llvm + +using namespace llvm; + +WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo( + const TargetRegisterInfo &TRI) {} + +const RegisterBankInfo::InstructionMapping & +WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { + + unsigned Opc = MI.getOpcode(); + const MachineFunction &MF = *MI.getParent()->getParent(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetSubtargetInfo &STI = MF.getSubtarget(); + const TargetRegisterInfo &TRI = *STI.getRegisterInfo(); + + if ((Opc != TargetOpcode::COPY && !isPreISelGenericOpcode(Opc)) || + Opc == TargetOpcode::G_PHI) { + const RegisterBankInfo::InstructionMapping &Mapping = + getInstrMappingImpl(MI); + if (Mapping.isValid()) + return Mapping; + } + + using namespace TargetOpcode; + + unsigned NumOperands = MI.getNumOperands(); + const ValueMapping *OperandsMapping = nullptr; + unsigned MappingID = DefaultMappingID; + + // Check if LLT sizes match sizes of available register banks. + for (const MachineOperand &Op : MI.operands()) { + if (Op.isReg()) { + LLT RegTy = MRI.getType(Op.getReg()); + + if (RegTy.isScalar() && + (RegTy.getSizeInBits() != 32 && RegTy.getSizeInBits() != 64)) + return getInvalidInstructionMapping(); + + if (RegTy.isVector() && RegTy.getSizeInBits() != 128) + return getInvalidInstructionMapping(); + } + } + switch (Opc) { + case G_BR: + return getInstructionMapping(MappingID, /*Cost=*/1, + getOperandsMapping({nullptr}), NumOperands); + case G_TRAP: + case G_DEBUGTRAP: + return getInstructionMapping(MappingID, /*Cost=*/1, getOperandsMapping({}), + 0); + case COPY: + Register DstReg = MI.getOperand(0).getReg(); + if (DstReg.isPhysical()) { + if (DstReg.id() == WebAssembly::SP32) { + return getInstructionMapping( + MappingID, /*Cost=*/1, + getOperandsMapping( + {&WebAssembly::ValueMappings[WebAssembly::I32Idx]}), + 1); + } else if (DstReg.id() == WebAssembly::SP64) { + return getInstructionMapping( + MappingID, /*Cost=*/1, + getOperandsMapping( + {&WebAssembly::ValueMappings[WebAssembly::I64Idx]}), + 1); + } else { + llvm_unreachable("Trying to copy into WASM physical register other " + "than sp32 or sp64?"); + } + } + break; + } + + const LLT Op0Ty = MRI.getType(MI.getOperand(0).getReg()); + unsigned Op0Size = Op0Ty.getSizeInBits(); + + auto &Op0IntValueMapping = + WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + auto &Op0FloatValueMapping = + WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + auto &Pointer0ValueMapping = + WebAssembly::ValueMappings[MI.getMF()->getDataLayout() + .getPointerSizeInBits(0) == 64 + ? 
WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + switch (Opc) { + case G_AND: + case G_OR: + case G_XOR: + case G_SHL: + case G_ASHR: + case G_LSHR: + case G_PTR_ADD: + case G_INTTOPTR: + case G_PTRTOINT: + case G_ADD: + case G_SUB: + case G_MUL: + case G_SDIV: + case G_SREM: + case G_UDIV: + case G_UREM: + case G_CTLZ: + case G_CTLZ_ZERO_UNDEF: + case G_CTTZ: + case G_CTTZ_ZERO_UNDEF: + case G_CTPOP: + case G_FSHL: + case G_FSHR: + OperandsMapping = &Op0IntValueMapping; + break; + case G_FADD: + case G_FSUB: + case G_FDIV: + case G_FMUL: + case G_FNEG: + case G_FABS: + case G_FCEIL: + case G_FFLOOR: + case G_FSQRT: + case G_INTRINSIC_TRUNC: + case G_FNEARBYINT: + case G_FRINT: + case G_INTRINSIC_ROUNDEVEN: + case G_FMINIMUM: + case G_FMAXIMUM: + case G_FMINNUM: + case G_FMAXNUM: + case G_FMINNUM_IEEE: + case G_FMAXNUM_IEEE: + case G_FMA: + case G_FREM: + case G_FCOPYSIGN: + OperandsMapping = &Op0FloatValueMapping; + break; + case G_SEXT_INREG: + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op0IntValueMapping, nullptr}); + break; + case G_FRAME_INDEX: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_VASTART: + OperandsMapping = &Op0IntValueMapping; + break; + case G_ZEXT: + case G_ANYEXT: + case G_SEXT: + case G_TRUNC: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op1IntValueMapping}); + break; + } + case G_LOAD: + case G_STORE: + if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) + break; + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Pointer0ValueMapping}); + break; + case G_MEMCPY: + case G_MEMMOVE: { + if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0) + break; + if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) + break; + + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + OperandsMapping = + getOperandsMapping({&Pointer0ValueMapping, &Pointer0ValueMapping, + &Op2IntValueMapping, nullptr}); + break; + } + case G_MEMSET: { + if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0) + break; + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? 
WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Pointer0ValueMapping, &Op1IntValueMapping, + &Op2IntValueMapping, nullptr}); + break; + } + case G_GLOBAL_VALUE: + case G_CONSTANT: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_FCONSTANT: + OperandsMapping = getOperandsMapping({&Op0FloatValueMapping, nullptr}); + break; + case G_IMPLICIT_DEF: + OperandsMapping = &Op0IntValueMapping; + break; + case G_ICMP: { + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2IntValueMapping, + &Op2IntValueMapping}); + break; + } + case G_FCMP: { + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + + auto &Op2FloatValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2FloatValueMapping, + &Op2FloatValueMapping}); + break; + } + case G_BRCOND: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_JUMP_TABLE: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_BRJT: + OperandsMapping = getOperandsMapping( + {&Op0IntValueMapping, nullptr, &Pointer0ValueMapping}); + break; + case COPY: { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + + const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI); + const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI); + + if (!DstRB) + DstRB = SrcRB; + else if (!SrcRB) + SrcRB = DstRB; + + assert(DstRB && SrcRB && "Both RegBank were nullptr"); + TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI); + TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI); + assert(DstSize == SrcSize && + "Trying to copy between different sized regbanks? Why?"); + + WebAssembly::ValueMappingIdx DstValMappingIdx = WebAssembly::InvalidIdx; + switch (DstRB->getID()) { + case WebAssembly::I32RegBankID: + DstValMappingIdx = WebAssembly::I32Idx; + break; + case WebAssembly::I64RegBankID: + DstValMappingIdx = WebAssembly::I64Idx; + break; + case WebAssembly::F32RegBankID: + DstValMappingIdx = WebAssembly::F32Idx; + break; + case WebAssembly::F64RegBankID: + DstValMappingIdx = WebAssembly::F64Idx; + break; + default: + break; + } + + WebAssembly::ValueMappingIdx SrcValMappingIdx = WebAssembly::InvalidIdx; + switch (SrcRB->getID()) { + case WebAssembly::I32RegBankID: + SrcValMappingIdx = WebAssembly::I32Idx; + break; + case WebAssembly::I64RegBankID: + SrcValMappingIdx = WebAssembly::I64Idx; + break; + case WebAssembly::F32RegBankID: + SrcValMappingIdx = WebAssembly::F32Idx; + break; + case WebAssembly::F64RegBankID: + SrcValMappingIdx = WebAssembly::F64Idx; + break; + default: + break; + } + + OperandsMapping = + getOperandsMapping({&WebAssembly::ValueMappings[DstValMappingIdx], + &WebAssembly::ValueMappings[SrcValMappingIdx]}); + return getInstructionMapping( + MappingID, /*Cost=*/copyCost(*DstRB, *SrcRB, DstSize), OperandsMapping, + // We only care about the mapping of the destination for COPY. 
+ 1); + } + case G_SELECT: + OperandsMapping = getOperandsMapping( + {&Op0IntValueMapping, &WebAssembly::ValueMappings[WebAssembly::I32Idx], + &Op0IntValueMapping, &Op0IntValueMapping}); + break; + case G_FPTOSI: + case G_FPTOSI_SAT: + case G_FPTOUI: + case G_FPTOUI_SAT: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1FloatValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op1FloatValueMapping}); + break; + } + case G_SITOFP: + case G_UITOFP: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, &Op1IntValueMapping}); + break; + } + case G_FPEXT: + case G_FPTRUNC: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1FloatValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, &Op1FloatValueMapping}); + break; + } + } + + if (!OperandsMapping) + return getInvalidInstructionMapping(); + + return getInstructionMapping(MappingID, /*Cost=*/1, OperandsMapping, + NumOperands); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h new file mode 100644 index 0000000000000..f0d95b56ef861 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h @@ -0,0 +1,40 @@ +//===- WebAssemblyRegisterBankInfo.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file declares the targeting of the RegisterBankInfo class for WASM. +/// \todo This should be generated by TableGen. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H + +#include "llvm/CodeGen/RegisterBankInfo.h" + +#define GET_REGBANK_DECLARATIONS +#include "WebAssemblyGenRegisterBank.inc" + +namespace llvm { + +class TargetRegisterInfo; + +class WebAssemblyGenRegisterBankInfo : public RegisterBankInfo { +#define GET_TARGET_REGBANK_CLASS +#include "WebAssemblyGenRegisterBank.inc" +}; + +/// This class provides the information for the target register banks. 
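+/// Wasm values live in typed locals rather than in bit-addressable machine +/// registers, so each value type (i32, i64, f32, f64, and the reference +/// types) gets a bank of its own.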
+class WebAssemblyRegisterBankInfo final + : public WebAssemblyGenRegisterBankInfo { +public: + WebAssemblyRegisterBankInfo(const TargetRegisterInfo &TRI); + + const InstructionMapping & + getInstrMapping(const MachineInstr &MI) const override; +}; +} // end namespace llvm +#endif diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index 2dbd597f01cc9..0c56c5077c563 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -15,6 +15,9 @@ #ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLY_H #define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLY_H +#include "GISel/WebAssemblyRegisterBankInfo.h" +#include "WebAssemblySubtarget.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" #include "llvm/PassRegistry.h" #include "llvm/Support/CodeGen.h" @@ -32,6 +35,12 @@ FunctionPass *createWebAssemblyOptimizeReturned(); FunctionPass *createWebAssemblyLowerRefTypesIntPtrConv(); FunctionPass *createWebAssemblyRefTypeMem2Local(); +// GlobalISel +InstructionSelector * +createWebAssemblyInstructionSelector(const WebAssemblyTargetMachine &, + const WebAssemblySubtarget &, + const WebAssemblyRegisterBankInfo &); + // ISel and immediate followup passes. FunctionPass *createWebAssemblyISelDag(WebAssemblyTargetMachine &TM, CodeGenOptLevel OptLevel); diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td index 089be5f1dc70e..3705a42fd21c9 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.td +++ b/llvm/lib/Target/WebAssembly/WebAssembly.td @@ -101,6 +101,7 @@ def FeatureWideArithmetic : //===----------------------------------------------------------------------===// include "WebAssemblyRegisterInfo.td" +include "WebAssemblyRegisterBanks.td" //===----------------------------------------------------------------------===// // Instruction Descriptions diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td new file mode 100644 index 0000000000000..5ed2dede7a080 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td @@ -0,0 +1,133 @@ +//===-- WebAssemblyGISel.td - WASM GlobalISel Patterns -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file contains patterns that are relevant to GlobalISel, including +/// GIComplexOperandMatcher definitions for equivalent SelectionDAG +/// ComplexPatterns. 
+// +//===----------------------------------------------------------------------===// + +include "WebAssembly.td" + + +//===----------------------------------------------------------------------===// +// Pointer types and related patterns +//===----------------------------------------------------------------------===// + +defvar ModeAddr32 = DefaultMode; +def ModeAddr64 : HwMode<[HasAddr64]>; + +def Addr0VT : ValueTypeByHwMode<[ModeAddr32, ModeAddr64], + [i32, i64]>; + +def p0 : PtrValueTypeByHwMode<Addr0VT, 0>; + +// G_CONSTANT with p0 +def : Pat<(p0 (imm:$addr)), + (CONST_I32 imm:$addr)>, Requires<[HasAddr32]>; +def : Pat<(p0 (imm:$addr)), + (CONST_I64 imm:$addr)>, Requires<[HasAddr64]>; + +// G_LOAD of p0 +def : Pat<(p0 (load (AddrOps32 offset32_op:$offset, I32:$addr))), + (LOAD_I32_A32 0, + offset32_op:$offset, + I32:$addr)>, + Requires<[HasAddr32]>; + +def : Pat<(p0 (load (AddrOps64 offset64_op:$offset, I64:$addr))), + (LOAD_I64_A64 0, + offset64_op:$offset, + I64:$addr)>, + Requires<[HasAddr64]>; + +// G_STORE of p0 +def : Pat<(store p0:$val, (AddrOps32 offset32_op:$offset, I32:$addr)), + (STORE_I32_A32 0, + offset32_op:$offset, + I32:$addr, + p0:$val)>, + Requires<[HasAddr32]>; + +def : Pat<(store p0:$val, (AddrOps64 offset64_op:$offset, I64:$addr)), + (STORE_I64_A64 0, + offset64_op:$offset, + I64:$addr, + p0:$val)>, + Requires<[HasAddr64]>; + +// G_SELECT of p0 +def : Pat<(select I32:$cond, p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$lhs, I32:$rhs, I32:$cond)>, Requires<[HasAddr32]>; +def : Pat<(select I32:$cond, p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$lhs, I64:$rhs, I32:$cond)>, Requires<[HasAddr64]>; + +// ISD::SELECT requires its operand to conform to getBooleanContents, but +// WebAssembly's select interprets any non-zero value as true, so we can fold +// a setne with 0 into a select. +def : Pat<(select (i32 (setne I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$lhs, I32:$rhs, I32:$cond)>, Requires<[HasAddr32]>; +def : Pat<(select (i32 (setne I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$lhs, I64:$rhs, I32:$cond)>, Requires<[HasAddr64]>; + +// And again, this time with seteq instead of setne and the arms reversed. 
+def : Pat<(select (i32 (seteq I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$rhs, I32:$lhs, I32:$cond)>, Requires<[HasAddr32]>; +def : Pat<(select (i32 (seteq I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$rhs, I64:$lhs, I32:$cond)>, Requires<[HasAddr64]>; + + +// G_ICMP between p0 +multiclass ComparisonP0<CondCode cond, string Name> { + def : Pat<(setcc p0:$lhs, p0:$rhs, cond), + (!cast<Instruction>(Name # "_I32") I32:$lhs, I32:$rhs)>, + Requires<[HasAddr32]>; + def : Pat<(setcc p0:$lhs, p0:$rhs, cond), + (!cast<Instruction>(Name # "_I64") I64:$lhs, I64:$rhs)>, + Requires<[HasAddr64]>; +} + +defm : ComparisonP0<SETEQ, "EQ">; +defm : ComparisonP0<SETNE, "NE">; +defm : ComparisonP0<SETLT, "LT_S">; +defm : ComparisonP0<SETULT, "LT_U">; +defm : ComparisonP0<SETGT, "GT_S">; +defm : ComparisonP0<SETUGT, "GT_U">; +defm : ComparisonP0<SETLE, "LE_S">; +defm : ComparisonP0<SETULE, "LE_U">; +defm : ComparisonP0<SETGE, "GE_S">; +defm : ComparisonP0<SETUGE, "GE_U">; + +//===----------------------------------------------------------------------===// +// Miscellaneous patterns +//===----------------------------------------------------------------------===// + +def : Pat<(i32 (fp_to_sint_sat_gi F32:$src)), (I32_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(i32 (fp_to_uint_sat_gi F32:$src)), (I32_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(i32 (fp_to_sint_sat_gi F64:$src)), (I32_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(i32 (fp_to_uint_sat_gi F64:$src)), (I32_TRUNC_U_SAT_F64 F64:$src)>; +def : Pat<(i64 (fp_to_sint_sat_gi F32:$src)), (I64_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(i64 (fp_to_uint_sat_gi F32:$src)), (I64_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(i64 (fp_to_sint_sat_gi F64:$src)), (I64_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(i64 (fp_to_uint_sat_gi F64:$src)), (I64_TRUNC_U_SAT_F64 F64:$src)>; + +def : GINodeEquiv; + +def : Pat<(i32 (ctlz_zero_undef I32:$src)), (CLZ_I32 I32:$src)>; +def : Pat<(i64 (ctlz_zero_undef I64:$src)), (CLZ_I64 I64:$src)>; +def : Pat<(i32 (cttz_zero_undef I32:$src)), (CTZ_I32 I32:$src)>; +def : Pat<(i64 (cttz_zero_undef I64:$src)), (CTZ_I64 I64:$src)>; + +//===----------------------------------------------------------------------===// +// Complex pattern equivalents +//===----------------------------------------------------------------------===// + +def gi_AddrOps32 : GIComplexOperandMatcher, + GIComplexPatternEquiv<AddrOps32>; + +def gi_AddrOps64 : GIComplexOperandMatcher, + GIComplexPatternEquiv<AddrOps64>; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td new file mode 100644 index 0000000000000..9ebece0e0bf09 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td @@ -0,0 +1,20 @@ +//=- WebAssemblyRegisterBanks.td - Describe the WASM Banks ---*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file describes the register banks used by WebAssembly. +// +//===----------------------------------------------------------------------===// + + +def I32RegBank : RegisterBank<"I32RegBank", [I32]>; +def I64RegBank : RegisterBank<"I64RegBank", [I64]>; +def F32RegBank : RegisterBank<"F32RegBank", [F32]>; +def F64RegBank : RegisterBank<"F64RegBank", [F64]>; + +def EXTERNREFRegBank : RegisterBank<"EXTERNREFRegBank", [EXTERNREF]>; +def FUNCREFRegBank : RegisterBank<"FUNCREFRegBank", [FUNCREF]>; +def EXNREFRegBank : RegisterBank<"EXNREFRegBank", [EXNREF]>; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp index a3ce40f0297ec..315cbb65371a0 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp @@ -12,9 +12,14 @@ /// //===----------------------------------------------------------------------===// +#include "WebAssembly.h" #include "WebAssemblySubtarget.h" +#include "GISel/WebAssemblyCallLowering.h" +#include "GISel/WebAssemblyLegalizerInfo.h" +#include "GISel/WebAssemblyRegisterBankInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "WebAssemblyInstrInfo.h" +#include "WebAssemblyTargetMachine.h" #include "llvm/MC/TargetRegistry.h" using namespace llvm; @@ -66,7 +71,15 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT, const TargetMachine &TM) : WebAssemblyGenSubtargetInfo(TT, CPU, /*TuneCPU*/ CPU, FS), TargetTriple(TT), InstrInfo(initializeSubtargetDependencies(CPU, FS)), - TLInfo(TM, *this) {} + TLInfo(TM, *this) { + CallLoweringInfo.reset(new WebAssemblyCallLowering(*getTargetLowering())); + Legalizer.reset(new WebAssemblyLegalizerInfo(*this)); + auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo()); + RegBankInfo.reset(RBI); + + InstSelector.reset(createWebAssemblyInstructionSelector( + *static_cast<const WebAssemblyTargetMachine *>(&TM), *this, *RBI)); +} bool WebAssemblySubtarget::enableAtomicExpand() const { // If atomics are disabled, atomic ops are lowered instead of expanded @@ -81,3 +94,19 @@ bool WebAssemblySubtarget::enableMachineScheduler() const { } bool WebAssemblySubtarget::useAA() const { return true; } + +const CallLowering *WebAssemblySubtarget::getCallLowering() const { + return CallLoweringInfo.get(); +} + +InstructionSelector *WebAssemblySubtarget::getInstructionSelector() const { + return InstSelector.get(); +} + +const LegalizerInfo *WebAssemblySubtarget::getLegalizerInfo() const { + return Legalizer.get(); +} + +const RegisterBankInfo *WebAssemblySubtarget::getRegBankInfo() const { + return RegBankInfo.get(); +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h index 2f88bbba05d00..c195f995009b1 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h @@ -20,6 +20,10 @@ #include "WebAssemblyISelLowering.h" #include "WebAssemblyInstrInfo.h" #include "WebAssemblySelectionDAGInfo.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" +#include "llvm/CodeGen/RegisterBankInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include <string> @@ -64,6 +68,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { WebAssemblySelectionDAGInfo TSInfo; WebAssemblyTargetLowering TLInfo; + 
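// GlobalISel support objects, created once in the subtarget constructor. + 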
std::unique_ptr<CallLowering> CallLoweringInfo; + std::unique_ptr<InstructionSelector> InstSelector; + std::unique_ptr<LegalizerInfo> Legalizer; + std::unique_ptr<RegisterBankInfo> RegBankInfo; + WebAssemblySubtarget &initializeSubtargetDependencies(StringRef CPU, StringRef FS); @@ -118,6 +127,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { /// Parses features string setting specified subtarget options. Definition of /// function is auto generated by tblgen. void ParseSubtargetFeatures(StringRef CPU, StringRef TuneCPU, StringRef FS); + + const CallLowering *getCallLowering() const override; + InstructionSelector *getInstructionSelector() const override; + const LegalizerInfo *getLegalizerInfo() const override; + const RegisterBankInfo *getRegBankInfo() const override; }; } // end namespace llvm diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index a9c638cde1259..6a3e8148837fa 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -20,6 +20,10 @@ #include "WebAssemblyTargetObjectFile.h" #include "WebAssemblyTargetTransformInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/CodeGen/GlobalISel/IRTranslator.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelect.h" +#include "llvm/CodeGen/GlobalISel/Legalizer.h" +#include "llvm/CodeGen/GlobalISel/RegBankSelect.h" #include "llvm/CodeGen/MIRParser/MIParser.h" #include "llvm/CodeGen/Passes.h" #include "llvm/CodeGen/RegAllocRegistry.h" @@ -92,6 +96,7 @@ LLVMInitializeWebAssemblyTarget() { // Register backend passes auto &PR = *PassRegistry::getPassRegistry(); + initializeGlobalISel(PR); initializeWebAssemblyAddMissingPrototypesPass(PR); initializeWebAssemblyLowerEmscriptenEHSjLjPass(PR); initializeLowerGlobalDtorsLegacyPassPass(PR); @@ -445,6 +450,11 @@ class WebAssemblyPassConfig final : public TargetPassConfig { // No reg alloc bool addRegAssignAndRewriteOptimized() override { return false; } + + bool addIRTranslator() override; + bool addLegalizeMachineIR() override; + bool addRegBankSelect() override; + bool addGlobalInstructionSelect() override; }; } // end anonymous namespace @@ -665,6 +675,32 @@ bool WebAssemblyPassConfig::addPreISel() { return false; } +bool WebAssemblyPassConfig::addIRTranslator() { + addPass(new IRTranslator()); + return false; +} + +bool WebAssemblyPassConfig::addLegalizeMachineIR() { + addPass(new Legalizer()); + return false; +} + +bool WebAssemblyPassConfig::addRegBankSelect() { + addPass(new RegBankSelect()); + return false; +} + +bool WebAssemblyPassConfig::addGlobalInstructionSelect() { + addPass(new InstructionSelect(getOptLevel())); + + addPass(createWebAssemblyArgumentMove()); + addPass(createWebAssemblySetP2AlignOperands()); + addPass(createWebAssemblyFixBrTableDefaults()); + addPass(createWebAssemblyCleanCodeAfterTrap()); + + return false; +} + yaml::MachineFunctionInfo * WebAssemblyTargetMachine::createDefaultFuncInfoYAML() const { return new yaml::WebAssemblyFunctionInfo();