Skip to content

Commit 2771974

Browse files
authored
[orc-rt] Add WrapperFunction::handle support for fns, fn-ptrs. (#157787)
Adds support for using functions and function pointers to the WrapperFunction::handle utility.
1 parent 224cad6 commit 2771974

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

orc-rt/include/orc-rt/WrapperFunction.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ struct WFCallableTraits<RetT(ArgT, ArgTs...)> {
121121
typedef std::tuple<ArgTs...> TailArgTuple;
122122
};
123123

124+
template <typename RetT, typename... ArgTs>
125+
struct WFCallableTraits<RetT (*)(ArgTs...)>
126+
: public WFCallableTraits<RetT(ArgTs...)> {};
127+
128+
template <typename RetT, typename... ArgTs>
129+
struct WFCallableTraits<RetT (&)(ArgTs...)>
130+
: public WFCallableTraits<RetT(ArgTs...)> {};
131+
124132
template <typename ClassT, typename RetT, typename... ArgTs>
125133
struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)>
126134
: public WFCallableTraits<RetT(ArgTs...)> {};

orc-rt/unittests/SPSWrapperFunctionTest.cpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,57 @@ TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) {
9090
EXPECT_TRUE(Ran);
9191
}
9292

93-
static void add_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
94-
orc_rt_WrapperFunctionReturn Return,
95-
orc_rt_WrapperFunctionBuffer ArgBytes) {
93+
static void add_via_lambda_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
94+
orc_rt_WrapperFunctionReturn Return,
95+
orc_rt_WrapperFunctionBuffer ArgBytes) {
9696
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
9797
Session, CallCtx, Return, ArgBytes,
9898
[](move_only_function<void(int32_t)> Return, int32_t X, int32_t Y) {
9999
Return(X + Y);
100100
});
101101
}
102102

103-
TEST(SPSWrapperFunctionUtilsTest, TestAdd) {
103+
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaLambda) {
104104
int32_t Result = 0;
105105
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
106-
DirectCaller(nullptr, add_sps_wrapper),
106+
DirectCaller(nullptr, add_via_lambda_sps_wrapper),
107+
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
108+
EXPECT_EQ(Result, 42);
109+
}
110+
111+
static void add_via_function(move_only_function<void(int32_t)> Return,
112+
int32_t X, int32_t Y) {
113+
Return(X + Y);
114+
}
115+
116+
static void
117+
add_via_function_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
118+
orc_rt_WrapperFunctionReturn Return,
119+
orc_rt_WrapperFunctionBuffer ArgBytes) {
120+
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
121+
Session, CallCtx, Return, ArgBytes, add_via_function);
122+
}
123+
124+
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunction) {
125+
int32_t Result = 0;
126+
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
127+
DirectCaller(nullptr, add_via_function_sps_wrapper),
128+
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
129+
EXPECT_EQ(Result, 42);
130+
}
131+
132+
static void
133+
add_via_function_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
134+
orc_rt_WrapperFunctionReturn Return,
135+
orc_rt_WrapperFunctionBuffer ArgBytes) {
136+
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
137+
Session, CallCtx, Return, ArgBytes, &add_via_function);
138+
}
139+
140+
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
141+
int32_t Result = 0;
142+
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
143+
DirectCaller(nullptr, add_via_function_pointer_sps_wrapper),
107144
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
108145
EXPECT_EQ(Result, 42);
109146
}

0 commit comments

Comments
 (0)