Skip to content

Commit 56b4653

Browse files
hoodmanemahesh-attarde
authored and committed
[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic (llvm#147486)
This adds an llvm intrinsic for WebAssembly to test the type of a function. It is intended for adding a future clang builtin `__builtin_wasm_test_function_pointer_signature` so we can test whether calling a function pointer will fail with a function signature mismatch.

Since the type of a function pointer is just `ptr` we can't figure out the expected type from that. The way I figured out to encode the type was by passing 0's of the appropriate type to the intrinsic. The first type argument gives the expected return type and the later values give the expected types of the arguments. So

```llvm
@llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, double 0.000000e+00, i32 0)
```

tests if `%func` is of type `(double, i32) -> (float)`. It will lower to:

```wat
local.get $func
table.get $__indirect_function_table
ref.test (double, i32) -> (float)
```

To indicate the function should be void, I somewhat arbitrarily picked `token poison`, so the following tests for `(i32) -> ()`:

```llvm
@llvm.wasm.ref.test.func(ptr %func, token poison, i32 0)
```

To lower this intrinsic, we need some place to put the type information. With `encodeFunctionSignature()` we encode the signature information into an `APInt`. We decode it in `lowerEncodedFunctionSignature` in `WebAssemblyMCInstLower.cpp`.
1 parent 2ea8e70 commit 56b4653

File tree

7 files changed

+276
-2
lines changed

7 files changed

+276
-2
lines changed

llvm/include/llvm/IR/IntrinsicsWebAssembly.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
4343
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
4444
"llvm.wasm.ref.is_null.exn">;
4545

46+
// Tests whether a function pointer refers to a function of a given signature.
// Produces an i32 result. The fixed operand is the function pointer; the
// variadic operands carry the signature to test for. Marked IntrNoMem.
def int_wasm_ref_test_func
47+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
48+
[IntrNoMem]>;
49+
4650
//===----------------------------------------------------------------------===//
4751
// Table intrinsics
4852
//===----------------------------------------------------------------------===//

llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
402402
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
403403
IsDebug, IsClone, IsCloned);
404404
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
405-
MIB.addImm(C->getSExtValue());
405+
if (C->getAPIntValue().getSignificantBits() <= 64) {
406+
MIB.addImm(C->getSExtValue());
407+
} else {
408+
MIB.addCImm(
409+
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
410+
}
406411
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
407412
MIB.addFPImm(F->getConstantFPValue());
408413
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {

llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
#include "WebAssembly.h"
1616
#include "WebAssemblyISelLowering.h"
1717
#include "WebAssemblyTargetMachine.h"
18+
#include "WebAssemblyUtilities.h"
1819
#include "llvm/CodeGen/MachineFrameInfo.h"
1920
#include "llvm/CodeGen/SelectionDAGISel.h"
2021
#include "llvm/CodeGen/WasmEHFuncInfo.h"
2122
#include "llvm/IR/DiagnosticInfo.h"
2223
#include "llvm/IR/Function.h" // To access function attributes.
2324
#include "llvm/IR/IntrinsicsWebAssembly.h"
25+
#include "llvm/MC/MCSymbolWasm.h"
2426
#include "llvm/Support/Debug.h"
2527
#include "llvm/Support/KnownBits.h"
2628
#include "llvm/Support/raw_ostream.h"
@@ -118,6 +120,51 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
118120
return DAG->getTargetExternalSymbol(SymName, PtrVT);
119121
}
120122

123+
/// Packs a wasm function signature into one wide integer so that it can be
/// carried through instruction selection as the type index placeholder operand
/// of ref.test. It is decoded again in WebAssemblyMCInstLower.cpp.
///
/// Layout, as 64-bit words from most significant to least significant:
///   [NReturns ^ 0x7ffffff] [return types...] [NParams] [param types...]
/// Every type word holds the numeric value of the matching wasm::ValType.
///
/// \param DAG     Not used by the encoding; kept for the existing call sites.
/// \param DL      Not used by the encoding; kept for the existing call sites.
/// \param Returns Result types of the signature being tested for.
/// \param Params  Parameter types of the signature being tested for.
/// \returns An APInt that is (NParams + NReturns + 2) * 64 bits wide.
static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
                                     SmallVector<MVT, 4> &Returns,
                                     SmallVector<MVT, 4> &Params) {
  // The mapping needs nothing from the enclosing scope, so capture nothing.
  // (Capturing DAG and DL without using them trips -Wunused-lambda-capture.)
  auto toWasmValType = [](MVT VT) {
    if (VT == MVT::i32)
      return wasm::ValType::I32;
    if (VT == MVT::i64)
      return wasm::ValType::I64;
    if (VT == MVT::f32)
      return wasm::ValType::F32;
    if (VT == MVT::f64)
      return wasm::ValType::F64;
    LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT
                      << "\n");
    // The intrinsic is variadic, so IR can hand us any legal type. Fail
    // loudly in release builds too instead of running into undefined
    // behaviour.
    report_fatal_error("Unhandled type for llvm.wasm.ref.test.func");
  };

  const size_t NParams = Params.size();
  const size_t NReturns = Returns.size();
  // Two header words (the two counts) plus one word per encoded type.
  const unsigned BitWidth = static_cast<unsigned>((NParams + NReturns + 2) * 64);
  APInt Sig(BitWidth, 0);

  // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter
  // will emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower
  // if we always emit a CImm. NReturns ends up in the most significant word,
  // so xor it with 0x7ffffff to ensure getSignificantBits() > 64.
  Sig |= NReturns ^ 0x7ffffff;
  for (MVT Return : Returns) {
    Sig <<= 64;
    Sig |= static_cast<uint64_t>(toWasmValType(Return));
  }

  Sig <<= 64;
  Sig |= NParams;
  for (MVT Param : Params) {
    Sig <<= 64;
    Sig |= static_cast<uint64_t>(toWasmValType(Param));
  }
  return Sig;
}
167+
121168
void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
122169
// If we have a custom node, we already have selected!
123170
if (Node->isMachineOpcode()) {
@@ -189,6 +236,50 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
189236
ReplaceNode(Node, TLSAlign);
190237
return;
191238
}
239+
case Intrinsic::wasm_ref_test_func: {
240+
// First emit the TABLE_GET instruction to convert function pointer ==>
241+
// funcref
242+
MachineFunction &MF = CurDAG->getMachineFunction();
243+
auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits());
244+
MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol(
245+
MF.getContext(), Subtarget);
246+
SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT);
247+
SDValue FuncRef = SDValue(
248+
CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
249+
MVT::funcref, TableSym, Node->getOperand(1)),
250+
0);
251+
252+
// Encode the signature information into the type index placeholder.
253+
// This gets decoded and converted into the actual type signature in
254+
// WebAssemblyMCInstLower.cpp.
255+
SmallVector<MVT, 4> Params;
256+
SmallVector<MVT, 4> Returns;
257+
258+
bool IsParam = false;
259+
// Operand 0 is the return register, Operand 1 is the function pointer.
260+
// The remaining operands encode the type of the function we are testing
261+
// for.
262+
for (unsigned I = 2, E = Node->getNumOperands(); I < E; ++I) {
263+
MVT VT = Node->getOperand(I).getValueType().getSimpleVT();
264+
if (VT == MVT::Untyped) {
265+
IsParam = true;
266+
continue;
267+
}
268+
if (IsParam) {
269+
Params.push_back(VT);
270+
} else {
271+
Returns.push_back(VT);
272+
}
273+
}
274+
auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params);
275+
276+
auto SigOp = CurDAG->getTargetConstant(
277+
Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
278+
MachineSDNode *RefTestNode = CurDAG->getMachineNode(
279+
WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef});
280+
ReplaceNode(Node, RefTestNode);
281+
return;
282+
}
192283
}
193284
break;
194285
}

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
798798

799799
if (IsIndirect) {
800800
// Placeholder for the type index.
801+
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
801802
MIB.addImm(0);
802803
// The table into which this call_indirect indexes.
803804
MCSymbolWasm *Table = IsFuncrefCall

llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@
1515
#include "WebAssemblyMCInstLower.h"
1616
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
1717
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18+
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
1819
#include "TargetInfo/WebAssemblyTargetInfo.h"
1920
#include "Utils/WebAssemblyTypeUtilities.h"
2021
#include "WebAssemblyAsmPrinter.h"
2122
#include "WebAssemblyMachineFunctionInfo.h"
2223
#include "WebAssemblyUtilities.h"
24+
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/SmallVector.h"
26+
#include "llvm/BinaryFormat/Wasm.h"
2327
#include "llvm/CodeGen/AsmPrinter.h"
2428
#include "llvm/CodeGen/MachineFunction.h"
29+
#include "llvm/CodeGen/MachineOperand.h"
2530
#include "llvm/IR/Constants.h"
2631
#include "llvm/MC/MCAsmInfo.h"
2732
#include "llvm/MC/MCContext.h"
@@ -152,6 +157,34 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand(
152157
return MCOperand::createExpr(Expr);
153158
}
154159

160+
/// Decodes a function signature that was packed into an APInt by
/// encodeFunctionSignature (WebAssemblyISelDAGToDAG.cpp) and lowers it to a
/// type index operand.
///
/// \p Sig is read as 64-bit words from most significant to least significant:
///   [NReturns ^ 0x7ffffff] [return types...] [NParams] [param types...]
/// Every type word holds the numeric value of a wasm::ValType.
///
/// \returns The operand produced by lowerTypeIndexOperand for the decoded
///          return and parameter lists.
MCOperand
WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const {
  SmallVector<wasm::ValType, 4> Params;
  SmallVector<wasm::ValType, 2> Returns;

  // For APInt a word is 64 bits on all architectures, see definition in APInt.h
  unsigned WordsLeft = Sig.getNumWords();
  // Yields the next word, walking from the most significant end downwards.
  auto GetWord = [&WordsLeft, &Sig]() -> uint64_t {
    assert(WordsLeft > 0 && "encoded function signature is too short");
    --WordsLeft;
    return Sig.extractBitsAsZExtValue(64, 64 * WordsLeft);
  };

  // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter
  // will emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower
  // if we always emit a CImm. So the encoder xors NReturns (the most
  // significant word) with 0x7ffffff to ensure getSignificantBits() > 64;
  // undo that here.
  // See encodeFunctionSignature in WebAssemblyISelDAGToDAG.cpp
  const uint64_t NReturns = GetWord() ^ 0x7ffffff;
  for (uint64_t I = 0; I < NReturns; ++I)
    Returns.push_back(static_cast<wasm::ValType>(GetWord()));

  const uint64_t NParams = GetWord();
  for (uint64_t I = 0; I < NParams; ++I)
    Params.push_back(static_cast<wasm::ValType>(GetWord()));

  // The encoder emits exactly two count words plus one word per type.
  assert(WordsLeft == 0 && "encoded function signature has trailing words");
  return lowerTypeIndexOperand(std::move(Returns), std::move(Params));
}
187+
155188
static void getFunctionReturns(const MachineInstr *MI,
156189
SmallVectorImpl<wasm::ValType> &Returns) {
157190
const Function &F = MI->getMF()->getFunction();
@@ -196,11 +229,29 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
196229
MCOp = MCOperand::createReg(WAReg);
197230
break;
198231
}
232+
case llvm::MachineOperand::MO_CImmediate: {
233+
// Lower type index placeholder for ref.test
234+
// Currently this is the only way that CImmediates show up so panic if we
235+
// get confused.
236+
unsigned DescIndex = I - NumVariadicDefs;
237+
assert(DescIndex < Desc.NumOperands && "unexpected CImmediate operand");
238+
auto Operands = Desc.operands();
239+
const MCOperandInfo &Info = Operands[DescIndex];
240+
assert(Info.OperandType == WebAssembly::OPERAND_TYPEINDEX &&
241+
"unexpected CImmediate operand");
242+
MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue());
243+
break;
244+
}
199245
case MachineOperand::MO_Immediate: {
200246
unsigned DescIndex = I - NumVariadicDefs;
201247
if (DescIndex < Desc.NumOperands) {
202-
const MCOperandInfo &Info = Desc.operands()[DescIndex];
248+
auto Operands = Desc.operands();
249+
const MCOperandInfo &Info = Operands[DescIndex];
250+
// Replace type index placeholder with actual type index. The type index
251+
// placeholders are Immediates and have an operand type of
252+
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
203253
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
254+
// Lower type index placeholder for a CALL_INDIRECT instruction
204255
SmallVector<wasm::ValType, 4> Returns;
205256
SmallVector<wasm::ValType, 4> Params;
206257

@@ -228,6 +279,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
228279
break;
229280
}
230281
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
282+
// Lower type index placeholder for blocks
231283
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
232284
assert(BT != WebAssembly::BlockType::Invalid);
233285
if (BT == WebAssembly::BlockType::Multivalue) {

llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower {
3636
MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const;
3737
MCOperand lowerTypeIndexOperand(SmallVectorImpl<wasm::ValType> &&,
3838
SmallVectorImpl<wasm::ValType> &&) const;
39+
MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const;
3940

4041
public:
4142
WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK32 %s
3+
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK64 %s
4+
5+
; No signature operands after the function pointer: expects the intrinsic to
; lower to table.get followed by `ref.test () -> ()`.
define void @test_fpsig_void_void(ptr noundef %func) local_unnamed_addr #0 {
6+
; CHECK-LABEL: test_fpsig_void_void:
7+
; CHK32: .functype test_fpsig_void_void (i32) -> ()
8+
; CHK64: .functype test_fpsig_void_void (i64) -> ()
9+
; CHECK-NEXT: # %bb.0: # %entry
10+
; CHECK-NEXT: local.get 0
11+
; CHECK-NEXT: table.get __indirect_function_table
12+
; CHECK-NEXT: ref.test () -> ()
13+
; CHECK-NEXT: call use
14+
; CHECK-NEXT: # fallthrough-return
15+
entry:
16+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func)
17+
tail call void @use(i32 noundef %res) #3
18+
ret void
19+
}
20+
21+
; A single `i32 0` operand and no `token poison` separator: expects a signature
; with an i32 result and no parameters, `ref.test () -> (i32)`.
define void @test_fpsig_return_i32(ptr noundef %func) local_unnamed_addr #0 {
22+
; CHECK-LABEL: test_fpsig_return_i32:
23+
; CHK32: .functype test_fpsig_return_i32 (i32) -> ()
24+
; CHK64: .functype test_fpsig_return_i32 (i64) -> ()
25+
; CHECK-NEXT: # %bb.0: # %entry
26+
; CHECK-NEXT: local.get 0
27+
; CHECK-NEXT: table.get __indirect_function_table
28+
; CHECK-NEXT: ref.test () -> (i32)
29+
; CHECK-NEXT: call use
30+
; CHECK-NEXT: # fallthrough-return
31+
entry:
32+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0)
33+
tail call void @use(i32 noundef %res) #3
34+
ret void
35+
}
36+
37+
; A single `i64 0` operand and no `token poison` separator: expects a signature
; with an i64 result and no parameters, `ref.test () -> (i64)`.
define void @test_fpsig_return_i64(ptr noundef %func) local_unnamed_addr #0 {
38+
; CHECK-LABEL: test_fpsig_return_i64:
39+
; CHK32: .functype test_fpsig_return_i64 (i32) -> ()
40+
; CHK64: .functype test_fpsig_return_i64 (i64) -> ()
41+
; CHECK-NEXT: # %bb.0: # %entry
42+
; CHECK-NEXT: local.get 0
43+
; CHECK-NEXT: table.get __indirect_function_table
44+
; CHECK-NEXT: ref.test () -> (i64)
45+
; CHECK-NEXT: call use
46+
; CHECK-NEXT: # fallthrough-return
47+
entry:
48+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i64 0)
49+
tail call void @use(i32 noundef %res) #3
50+
ret void
51+
}
52+
53+
; A single `float 0.` operand and no `token poison` separator: expects a
; signature with an f32 result and no parameters, `ref.test () -> (f32)`.
define void @test_fpsig_return_f32(ptr noundef %func) local_unnamed_addr #0 {
54+
; CHECK-LABEL: test_fpsig_return_f32:
55+
; CHK32: .functype test_fpsig_return_f32 (i32) -> ()
56+
; CHK64: .functype test_fpsig_return_f32 (i64) -> ()
57+
; CHECK-NEXT: # %bb.0: # %entry
58+
; CHECK-NEXT: local.get 0
59+
; CHECK-NEXT: table.get __indirect_function_table
60+
; CHECK-NEXT: ref.test () -> (f32)
61+
; CHECK-NEXT: call use
62+
; CHECK-NEXT: # fallthrough-return
63+
entry:
64+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.)
65+
tail call void @use(i32 noundef %res) #3
66+
ret void
67+
}
68+
69+
; A single `double 0.` operand and no `token poison` separator: expects a
; signature with an f64 result and no parameters, `ref.test () -> (f64)`.
define void @test_fpsig_return_f64(ptr noundef %func) local_unnamed_addr #0 {
70+
; CHECK-LABEL: test_fpsig_return_f64:
71+
; CHK32: .functype test_fpsig_return_f64 (i32) -> ()
72+
; CHK64: .functype test_fpsig_return_f64 (i64) -> ()
73+
; CHECK-NEXT: # %bb.0: # %entry
74+
; CHECK-NEXT: local.get 0
75+
; CHECK-NEXT: table.get __indirect_function_table
76+
; CHECK-NEXT: ref.test () -> (f64)
77+
; CHECK-NEXT: call use
78+
; CHECK-NEXT: # fallthrough-return
79+
entry:
80+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, double 0.)
81+
tail call void @use(i32 noundef %res) #3
82+
ret void
83+
}
84+
85+
86+
; `token poison` first, then one type operand: expects no results and a single
; parameter, `ref.test (f64) -> ()`.
; NOTE(review): despite the `_i32` in the name, the parameter operand used here
; is `double 0.` and the expected signature is `(f64) -> ()`.
define void @test_fpsig_param_i32(ptr noundef %func) local_unnamed_addr #0 {
87+
; CHECK-LABEL: test_fpsig_param_i32:
88+
; CHK32: .functype test_fpsig_param_i32 (i32) -> ()
89+
; CHK64: .functype test_fpsig_param_i32 (i64) -> ()
90+
; CHECK-NEXT: # %bb.0: # %entry
91+
; CHECK-NEXT: local.get 0
92+
; CHECK-NEXT: table.get __indirect_function_table
93+
; CHECK-NEXT: ref.test (f64) -> ()
94+
; CHECK-NEXT: call use
95+
; CHECK-NEXT: # fallthrough-return
96+
entry:
97+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, double 0.)
98+
tail call void @use(i32 noundef %res) #3
99+
ret void
100+
}
101+
102+
103+
; Four type operands, `token poison`, then three more type operands: the
; operands before the token are expected as results and those after it as
; parameters, giving `ref.test (i64, f32, i64) -> (i32, i64, f32, f64)`.
define void @test_fpsig_multiple_params_and_returns(ptr noundef %func) local_unnamed_addr #0 {
104+
; CHECK-LABEL: test_fpsig_multiple_params_and_returns:
105+
; CHK32: .functype test_fpsig_multiple_params_and_returns (i32) -> ()
106+
; CHK64: .functype test_fpsig_multiple_params_and_returns (i64) -> ()
107+
; CHECK-NEXT: # %bb.0: # %entry
108+
; CHECK-NEXT: local.get 0
109+
; CHECK-NEXT: table.get __indirect_function_table
110+
; CHECK-NEXT: ref.test (i64, f32, i64) -> (i32, i64, f32, f64)
111+
; CHECK-NEXT: call use
112+
; CHECK-NEXT: # fallthrough-return
113+
entry:
114+
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i64 0, float 0., double 0., token poison, i64 0, float 0., i64 0)
115+
tail call void @use(i32 noundef %res) #3
116+
ret void
117+
}
118+
119+
120+
declare void @use(i32 noundef) local_unnamed_addr #1

0 commit comments

Comments
 (0)