Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem], "llvm.wasm.ref.test.func">;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
MIB.addImm(C->getSExtValue());
if (C->getAPIntValue().getBitWidth() <= 64) {
MIB.addImm(C->getSExtValue());
} else {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
Expand Down Expand Up @@ -794,6 +795,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down Expand Up @@ -2253,6 +2255,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = DAG.getMachineFunction();
auto PtrVT = getPointerTy(MF.getDataLayout());
MCSymbol *Table =
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
SDValue FuncRef =
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Op.getOperand(1)),
0);

// Encode the signature information into the type index placeholder.
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
auto NParams = Op.getNumOperands() - 2;
auto BitWidth = (NParams + 1) * 64;
auto Sig = APInt(BitWidth, 0);
// The return type has to be a BlockType since it can be void.
{
SDValue Operand = Op.getOperand(2);
MVT VT = Operand.getValueType().getSimpleVT();
WebAssembly::BlockType V;
if (VT == MVT::Untyped) {
V = WebAssembly::BlockType::Void;
} else if (VT == MVT::i32) {
V = WebAssembly::BlockType::I32;
} else if (VT == MVT::i64) {
V = WebAssembly::BlockType::I64;
} else if (VT == MVT::f32) {
V = WebAssembly::BlockType::F32;
} else if (VT == MVT::f64) {
V = WebAssembly::BlockType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig |= (int64_t)V;
}
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
SDValue Operand = Op.getOperand(i);
MVT VT = Operand.getValueType().getSimpleVT();
wasm::ValType V;
if (VT == MVT::i32) {
V = wasm::ValType::I32;
} else if (VT == MVT::i64) {
V = wasm::ValType::I64;
} else if (VT == MVT::f32) {
V = wasm::ValType::F32;
} else if (VT == MVT::f64) {
V = wasm::ValType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig <<= 64;
Sig |= (int64_t)V;
}

SmallVector<SDValue, 4> Ops;
Ops.push_back(DAG.getTargetConstant(
Sig, DL, EVT::getIntegerVT(*DAG.getContext(), BitWidth)));
Ops.push_back(FuncRef);
return SDValue(
DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
0);
}
}
}

Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex >= Desc.NumOperands) {
llvm_unreachable("unexpected CImmediate operand");
}
const MCOperandInfo &Info = Desc.operands()[DescIndex];
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
auto CImm = MO.getCImm()->getValue();
auto NumWords = CImm.getNumWords() - 1;
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
// We need to load the words from most significant to least significant
// order because of the way we bitshifted them in from the right.
// The return type needs special handling because it could be void.
auto ReturnType = static_cast<WebAssembly::BlockType>(
CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
SmallVector<wasm::ValType, 2> Returns;
switch (ReturnType) {
case WebAssembly::BlockType::Invalid:
llvm_unreachable("Invalid return type");
case WebAssembly::BlockType::I32:
Returns = {wasm::ValType::I32};
break;
case WebAssembly::BlockType::I64:
Returns = {wasm::ValType::I64};
break;
case WebAssembly::BlockType::F32:
Returns = {wasm::ValType::F32};
break;
case WebAssembly::BlockType::F64:
Returns = {wasm::ValType::F64};
break;
case WebAssembly::BlockType::Void:
Returns = {};
break;
case WebAssembly::BlockType::Exnref:
Returns = {wasm::ValType::EXNREF};
break;
case WebAssembly::BlockType::Externref:
Returns = {wasm::ValType::EXTERNREF};
break;
case WebAssembly::BlockType::Funcref:
Returns = {wasm::ValType::FUNCREF};
break;
case WebAssembly::BlockType::V128:
Returns = {wasm::ValType::V128};
break;
case WebAssembly::BlockType::Multivalue: {
llvm_unreachable("Invalid return type");
}
}
SmallVector<wasm::ValType, 4> Params;

for (int I = NumWords - 2; I >= 0; I--) {
auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
auto ParamType = static_cast<wasm::ValType>(Val);
Params.push_back(ParamType);
}
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s

target triple = "wasm32-unknown-unknown"

; CHECK-LABEL: test_function_pointer_signature_void:
; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> ()
; CHECK-NEXT: .local funcref
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use

; Function Attrs: nounwind
define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 {
entry:
%0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %0) #3
%1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %1) #3
%2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %2) #3
%3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %3) #3
%4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
tail call void @use(i32 noundef %4) #3
ret void
}

declare void @use(i32 noundef) local_unnamed_addr #1
Loading