-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[WebAssembly] Implement lowering calls through funcref to call_ref when available #162227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
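@llvm/pr-subscribers-backend-webassembly Author: Demetrius Kanios (QuantumSegfault) Changes: Allows calls through `funcref` (`ptr addrspace(20)`) to be lowered to a sequence of `ref.cast` + `call_ref` when WasmGC is available. This is as opposed to the current workaround of storing the funcref into a special table and using `call_indirect`. Builds upon the framework provided by #147486. Example Source IR: define void @call_ref_void(%funcref %callee) {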
call addrspace(20) void %callee()
ret void
} Result Before this PR and/or without GC: i32.const 0
local.get 0
table.set __funcref_call_table
i32.const 0
call_indirect __funcref_call_table, () -> ()
i32.const 0
ref.null_func
table.set __funcref_call_table After this PR, when compiled with -mattr=+gc: local.get 0
ref.cast () -> ()
call_ref () -> () Full diff: https://github.com/llvm/llvm-project/pull/162227.diff 8 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
index fe9a4bada2430..db4d9edb152ce 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
@@ -435,6 +435,18 @@ inline bool isCallIndirect(unsigned Opc) {
}
}
+/// Returns true if \p Opc is one of the call_ref / return_call_ref opcodes
+/// (including their _S variants).
+inline bool isCallRef(unsigned Opc) {
+  return Opc == WebAssembly::CALL_REF || Opc == WebAssembly::CALL_REF_S ||
+         Opc == WebAssembly::RET_CALL_REF ||
+         Opc == WebAssembly::RET_CALL_REF_S;
+}
+
inline bool isBrTable(unsigned Opc) {
switch (Opc) {
case WebAssembly::BR_TABLE_I32:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
index 2541b0433ab59..03c90c7160a68 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
@@ -120,60 +120,6 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
return DAG->getTargetExternalSymbol(SymName, PtrVT);
}
-static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
- SmallVector<MVT, 4> &Returns,
- SmallVector<MVT, 4> &Params) {
- auto toWasmValType = [](MVT VT) {
- if (VT == MVT::i32) {
- return wasm::ValType::I32;
- }
- if (VT == MVT::i64) {
- return wasm::ValType::I64;
- }
- if (VT == MVT::f32) {
- return wasm::ValType::F32;
- }
- if (VT == MVT::f64) {
- return wasm::ValType::F64;
- }
- if (VT == MVT::externref) {
- return wasm::ValType::EXTERNREF;
- }
- if (VT == MVT::funcref) {
- return wasm::ValType::FUNCREF;
- }
- if (VT == MVT::exnref) {
- return wasm::ValType::EXNREF;
- }
- LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT
- << "\n");
- llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func");
- };
- auto NParams = Params.size();
- auto NReturns = Returns.size();
- auto BitWidth = (NParams + NReturns + 2) * 64;
- auto Sig = APInt(BitWidth, 0);
-
- // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
- // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
- // always emit a CImm. So xor NParams with 0x7ffffff to ensure
- // getSignificantBits() > 64
- Sig |= NReturns ^ 0x7ffffff;
- for (auto &Return : Returns) {
- auto V = toWasmValType(Return);
- Sig <<= 64;
- Sig |= (int64_t)V;
- }
- Sig <<= 64;
- Sig |= NParams;
- for (auto &Param : Params) {
- auto V = toWasmValType(Param);
- Sig <<= 64;
- Sig |= (int64_t)V;
- }
- return Sig;
-}
-
void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
// If we have a custom node, we already have selected!
if (Node->isMachineOpcode()) {
@@ -288,7 +234,8 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
Returns.push_back(VT);
}
}
- auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params);
+ // The signature encoder is now shared with funcref call_ref lowering, so it
+ // lives in the WebAssembly namespace (WebAssemblyISelLowering.cpp).
+ auto Sig =
+ WebAssembly::encodeFunctionSignature(CurDAG, DL, Returns, Params);
auto SigOp = CurDAG->getTargetConstant(
Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 163bf9ba5b089..bd0733c73f7ed 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -723,6 +723,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
bool IsIndirect =
CallParams.getOperand(0).isReg() || CallParams.getOperand(0).isFI();
bool IsRetCall = CallResults.getOpcode() == WebAssembly::RET_CALL_RESULTS;
+ bool IsCallRef = false;
bool IsFuncrefCall = false;
if (IsIndirect && CallParams.getOperand(0).isReg()) {
@@ -732,10 +733,19 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
const TargetRegisterClass *TRC = MRI.getRegClass(Reg);
IsFuncrefCall = (TRC == &WebAssembly::FUNCREFRegClass);
assert(!IsFuncrefCall || Subtarget->hasReferenceTypes());
+ // With GC, call funcrefs directly via call_ref, not __funcref_call_table.
+ if (IsFuncrefCall && Subtarget->hasGC()) {
+ IsIndirect = false;
+ IsCallRef = true;
+ }
}
unsigned CallOp;
- if (IsIndirect && IsRetCall) {
+ if (IsCallRef && IsRetCall) {
+ CallOp = WebAssembly::RET_CALL_REF;
+ } else if (IsCallRef) {
+ CallOp = WebAssembly::CALL_REF;
+ } else if (IsIndirect && IsRetCall) {
CallOp = WebAssembly::RET_CALL_INDIRECT;
} else if (IsIndirect) {
CallOp = WebAssembly::CALL_INDIRECT;
@@ -771,6 +781,14 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
CallParams.addOperand(FnPtr);
}
+ // call_ref takes the funcref callee as its last operand, so move it there.
+ if (IsCallRef) {
+ auto FnRef = CallParams.getOperand(0);
+ CallParams.removeOperand(0);
+
+ CallParams.addOperand(FnRef);
+ }
+
for (auto Def : CallResults.defs())
MIB.add(Def);
@@ -795,6 +813,12 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
}
}
+ if (IsCallRef) {
+ // Placeholder for the type index.
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
+ MIB.addImm(0);
+ }
+
for (auto Use : CallParams.uses())
MIB.add(Use);
@@ -1173,6 +1197,60 @@ static bool callingConvSupported(CallingConv::ID CallConv) {
CallConv == CallingConv::Swift;
}
+/// Encode a WebAssembly function signature as a wide APInt.
+///
+/// The result is used as a type-index placeholder operand (it is emitted as a
+/// CImm on the machine instruction, e.g. for ref.test / ref.cast on funcrefs)
+/// and is lowered to a real type index in WebAssemblyMCInstLower.
+///
+/// Layout, one 64-bit word per entry, from most- to least-significant:
+///   NReturns ^ 0x7ffffff, Returns..., NParams, Params...
+/// so the APInt is (NReturns + NParams + 2) * 64 bits wide.
+///
+/// \p DAG and \p DL are not used by the encoding itself.
+APInt WebAssembly::encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
+                                           SmallVector<MVT, 4> &Returns,
+                                           SmallVector<MVT, 4> &Params) {
+  // Map a legal MVT onto the corresponding wasm value type.
+  auto toWasmValType = [](MVT VT) {
+    if (VT == MVT::i32)
+      return wasm::ValType::I32;
+    if (VT == MVT::i64)
+      return wasm::ValType::I64;
+    if (VT == MVT::f32)
+      return wasm::ValType::F32;
+    if (VT == MVT::f64)
+      return wasm::ValType::F64;
+    if (VT == MVT::externref)
+      return wasm::ValType::EXTERNREF;
+    if (VT == MVT::funcref)
+      return wasm::ValType::FUNCREF;
+    if (VT == MVT::exnref)
+      return wasm::ValType::EXNREF;
+    // This encoder is shared by ref.test and call_ref lowering, so the
+    // diagnostic must not name a single intrinsic.
+    LLVM_DEBUG(errs() << "Unhandled type in function signature encoding: "
+                      << VT << "\n");
+    llvm_unreachable("Unhandled type in function signature encoding");
+  };
+
+  const size_t NParams = Params.size();
+  const size_t NReturns = Returns.size();
+  const unsigned BitWidth = static_cast<unsigned>((NParams + NReturns + 2) * 64);
+  APInt Sig(BitWidth, 0);
+
+  // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
+  // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
+  // always emit a CImm. So xor NReturns (the most-significant word) with
+  // 0x7ffffff to ensure getSignificantBits() > 64.
+  Sig |= NReturns ^ 0x7ffffff;
+  for (MVT Return : Returns) {
+    Sig <<= 64;
+    Sig |= static_cast<uint64_t>(toWasmValType(Return));
+  }
+  Sig <<= 64;
+  Sig |= NParams;
+  for (MVT Param : Params) {
+    Sig <<= 64;
+    Sig |= static_cast<uint64_t>(toWasmValType(Param));
+  }
+  return Sig;
+}
+
SDValue
WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1412,33 +1490,62 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
InTys.push_back(In.VT);
}
- // Lastly, if this is a call to a funcref we need to add an instruction
- // table.set to the chain and transform the call.
+ // Lastly, if this is a call to a funcref we need to insert an instruction
+ // to either cast the funcref to a typed funcref for call_ref, or place it
+ // into a table for call_indirect
if (CLI.CB && WebAssembly::isWebAssemblyFuncrefType(
CLI.CB->getCalledOperand()->getType())) {
- // In the absence of function references proposal where a funcref call is
- // lowered to call_ref, using reference types we generate a table.set to set
- // the funcref to a special table used solely for this purpose, followed by
- // a call_indirect. Here we just generate the table set, and return the
- // SDValue of the table.set so that LowerCall can finalize the lowering by
- // generating the call_indirect.
- SDValue Chain = Ops[0];
+ if (Subtarget->hasGC()) {
+ // Since LLVM doesn't directly support typed function references, we take
+ // the untyped funcref and ref.cast it into a typed funcref.
+ //
+ // The cast type must be identical to the type given to call_ref, which
+ // is computed from the call's register operands when lowering to MC. So
+ // derive the parameter types from Ops (Ops[0] is the chain, Ops[1] the
+ // callee) rather than from Outs: for varargs calls Outs also lists the
+ // variadic arguments, which travel in a buffer, not as call operands.
+ SmallVector<MVT, 4> Params;
+ for (unsigned I = 2, E = Ops.size(); I != E; ++I)
+ Params.push_back(Ops[I].getSimpleValueType());
+
+ SmallVector<MVT, 4> Returns;
+ for (const auto &In : Ins)
+ Returns.push_back(In.VT);
- MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
- MF.getContext(), Subtarget);
- SDValue Sym = DAG.getMCSymbol(Table, PtrVT);
- SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32);
- SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee};
- SDValue TableSet = DAG.getMemIntrinsicNode(
- WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps,
- MVT::funcref,
- // Machine Mem Operand args
- MachinePointerInfo(
- WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF),
- CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()),
- MachineMemOperand::MOStore);
-
- Ops[0] = TableSet; // The new chain is the TableSet itself
+ auto Sig =
+ WebAssembly::encodeFunctionSignature(&DAG, DL, Returns, Params);
+
+ auto SigOp = DAG.getTargetConstant(
+ Sig, DL, EVT::getIntegerVT(*DAG.getContext(), Sig.getBitWidth()));
+ MachineSDNode *RefCastNode = DAG.getMachineNode(
+ WebAssembly::REF_CAST_FUNCREF, DL, MVT::funcref, {SigOp, Callee});
+
+ Ops[1] = SDValue(RefCastNode, 0);
+ } else {
+ // In the absence of function references proposal where a funcref call is
+ // lowered to call_ref, using reference types we generate a table.set to
+ // set the funcref to a special table used solely for this purpose,
+ // followed by a call_indirect. Here we just generate the table set, and
+ // return the SDValue of the table.set so that LowerCall can finalize the
+ // lowering by generating the call_indirect.
+ SDValue Chain = Ops[0];
+
+ MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
+ MF.getContext(), Subtarget);
+ SDValue Sym = DAG.getMCSymbol(Table, PtrVT);
+ SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32);
+ SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee};
+ SDValue TableSet = DAG.getMemIntrinsicNode(
+ WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps,
+ MVT::funcref,
+ // Machine Mem Operand args
+ MachinePointerInfo(
+ WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF),
+ CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()),
+ MachineMemOperand::MOStore);
+
+ Ops[0] = TableSet; // The new chain is the TableSet itself
+ }
}
if (CLI.IsTailCall) {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index b33a8530310be..7d2194132f293 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -141,6 +141,11 @@ class WebAssemblyTargetLowering final : public TargetLowering {
namespace WebAssembly {
FastISel *createFastISel(FunctionLoweringInfo &funcInfo,
const TargetLibraryInfo *libInfo);
+
+/// Encode a function signature (\p Returns, then \p Params) as a wide APInt
+/// to be used as a type-index placeholder operand (e.g. for ref.test and
+/// ref.cast); WebAssemblyMCInstLower lowers it to a real type index.
+APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
+ SmallVector<MVT, 4> &Returns,
+ SmallVector<MVT, 4> &Params);
+
} // end namespace WebAssembly
} // end namespace llvm
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
index ca9a5ef9dda1c..81b62f6a682ec 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
@@ -66,6 +66,16 @@ defm CALL_INDIRECT :
[],
"call_indirect", "call_indirect\t$type, $table", 0x11>;
+// call_ref: call through a typed function reference. $type is a placeholder
+// filled in with the call's signature in WebAssemblyMCInstLower; the callee
+// reference is the last variadic operand (after the arguments).
+let variadicOpsAreDefs = 1 in
+defm CALL_REF :
+ I<(outs),
+ (ins TypeIndex:$type, variable_ops),
+ (outs),
+ (ins TypeIndex:$type),
+ [],
+ "call_ref", "call_ref\t$type", 0x14>,
+ Requires<[HasGC]>;
+
let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
defm RET_CALL :
I<(outs), (ins function32_op:$callee, variable_ops),
@@ -81,4 +91,14 @@ defm RET_CALL_INDIRECT :
0x13>,
Requires<[HasTailCall]>;
+// return_call_ref: tail-call form of call_ref; needs both tail-call and GC.
+let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
+defm RET_CALL_REF :
+ I<(outs),
+ (ins TypeIndex:$type, variable_ops),
+ (outs),
+ (ins TypeIndex:$type),
+ [],
+ "return_call_ref", "return_call_ref\t$type", 0x15>,
+ Requires<[HasTailCall, HasGC]>;
+
} // Uses = [SP32,SP64], isCall = 1
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
index fc82e5b4a61da..6fa6ed897d647 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
@@ -41,6 +41,11 @@ defm REF_TEST_FUNCREF : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref),
"ref.test\t$type, $ref", "ref.test $type", 0xfb14>,
Requires<[HasGC]>;
+// ref.cast of an untyped funcref to the function type given by $type; used to
+// obtain a typed function reference before call_ref.
+defm REF_CAST_FUNCREF : I<(outs FUNCREF:$res), (ins TypeIndex:$type, FUNCREF:$ref),
+ (outs), (ins TypeIndex:$type), [],
+ "ref.cast\t$type, $ref", "ref.cast $type", 0xfb16>,
+ Requires<[HasGC]>;
+
defm "" : REF_I<FUNCREF, funcref, "func">;
defm "" : REF_I<EXTERNREF, externref, "extern">;
defm "" : REF_I<EXNREF, exnref, "exn">;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index e48283aadb437..1ed15967c01fe 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -230,7 +230,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
case llvm::MachineOperand::MO_CImmediate: {
- // Lower type index placeholder for ref.test
+ // Lower type index placeholder (encoded signature) for ref.test and ref.cast
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
@@ -266,14 +266,16 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
Params.push_back(WebAssembly::regClassToValType(
MRI.getRegClass(MO.getReg())->getID()));
- // call_indirect instructions have a callee operand at the end which
- // doesn't count as a param.
- if (WebAssembly::isCallIndirect(MI->getOpcode()))
+ // call_indirect and call_ref instructions have a callee operand at
+ // the end (after the arguments) which doesn't count as a param.
+ if (WebAssembly::isCallIndirect(MI->getOpcode()) ||
+ WebAssembly::isCallRef(MI->getOpcode()))
Params.pop_back();
- // return_call_indirect instructions have the return type of the
- // caller
- if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT)
+ // return_call_indirect and return_call_ref instructions have the
+ // return type of the caller
+ if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT ||
+ MI->getOpcode() == WebAssembly::RET_CALL_REF)
getFunctionReturns(MI, Returns);
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
diff --git a/llvm/test/CodeGen/WebAssembly/call-ref.ll b/llvm/test/CodeGen/WebAssembly/call-ref.ll
new file mode 100644
index 0000000000000..25fc7440ac64c
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/call-ref.ll
@@ -0,0 +1,51 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mattr=+reference-types,-gc | FileCheck --check-prefixes=CHECK,NOGC %s
+; RUN: llc < %s -mattr=+reference-types,+gc | FileCheck --check-prefixes=CHECK,GC %s
+
+; Test that calls through funcref lower to call_ref when GC is available:
+; with GC the funcref is ref.cast to the call's signature and then called
+; with call_ref; without GC it is stored into slot 0 of __funcref_call_table,
+; called with call_indirect, and the slot is reset to ref.null afterwards.
+
+target triple = "wasm32-unknown-unknown"
+
+%funcref = type ptr addrspace(20);
+
+define void @call_ref_void(%funcref %callee) {
+; CHECK-LABEL: call_ref_void:
+; CHECK: .functype call_ref_void (funcref) -> ()
+; CHECK-NEXT: # %bb.0:
+; NOGC-NEXT: i32.const 0
+; CHECK-NEXT: local.get 0
+; NOGC-NEXT: table.set __funcref_call_table
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: call_indirect __funcref_call_table, () -> ()
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: ref.null_func
+; NOGC-NEXT: table.set __funcref_call_table
+; GC-NEXT: ref.cast () -> ()
+; GC-NEXT: call_ref () -> ()
+; CHECK-NEXT: # fallthrough-return
+ call addrspace(20) void %callee()
+ ret void
+}
+
+define void @call_ref_with_args_and_ret(%funcref %callee) {
+; CHECK-LABEL: call_ref_with_args_and_ret:
+; CHECK: .functype call_ref_with_args_and_ret (funcref) -> ()
+; CHECK-NEXT: # %bb.0:
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: local.get 0
+; NOGC-NEXT: table.set __funcref_call_table
+; CHECK-NEXT: i32.const 1
+; CHECK-NEXT: f64.const 0x1p1
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: call_indirect __funcref_call_table, (i32, f64) -> (i32)
+; GC-NEXT: local.get 0
+; GC-NEXT: ref.cast (i32, f64) -> (i32)
+; GC-NEXT: call_ref (i32, f64) -> (i32)
+; CHECK-NEXT: drop
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: ref.null_func
+; NOGC-NEXT: table.set __funcref_call_table
+; CHECK-NEXT: # fallthrough-return
+ %result = call addrspace(20) i32 %callee(i32 1, double 2.0)
+ ret void
+}
|
Allows calls through `funcref` (`ptr addrspace(20)`) to be lowered to a sequence of `ref.cast` + `call_ref` when WasmGC is available. This is as opposed to the current workaround of storing the funcref into a special table and using `call_indirect`.
Builds upon the framework provided by #147486.
Example
Source IR
Result
Before this PR and/or without GC:
After this PR, when compiled with `-mattr=+gc`: