Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions orc-rt/include/orc-rt/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "orc-rt-c/WrapperFunction.h"
#include "orc-rt/CallableTraitsHelper.h"
#include "orc-rt/Error.h"
#include "orc-rt/ExecutorAddress.h"
#include "orc-rt/bind.h"
#include "orc-rt/move_only_function.h"

#include <utility>

Expand Down Expand Up @@ -205,6 +207,128 @@ struct ResultDeserializer<std::tuple<Error>, Serializer> {
/// wrapper functions in C++.
struct WrapperFunction {

/// Wraps an asynchronous method (a method returning void, and taking a
/// return callback as its first argument) for use with
/// WrapperFunction::handle.
///
/// AsyncMethod's call operator takes an ExecutorAddr as its second argument,
/// casts it to a ClassT*, and then calls the wrapped method on that pointer,
/// forwarding the return callback and any subsequent arguments (after the
/// second argument representing the object address).
///
/// This utility removes some of the boilerplate from writing wrappers for
/// method calls.
template <typename ClassT, typename ReturnT, typename... ArgTs>
struct AsyncMethod {
AsyncMethod(void (ClassT::*M)(ReturnT, ArgTs...)) : M(M) {}
void operator()(ReturnT &&Return, ExecutorAddr Obj, ArgTs &&...Args) {
(Obj.toPtr<ClassT *>()->*M)(std::forward<ReturnT>(Return),
std::forward<ArgTs>(Args)...);
}

private:
void (ClassT::*M)(ReturnT, ArgTs...);
};

/// Create an AsyncMethod wrapper for the given method pointer. The given
/// method should be asynchronous: returning void, and taking a return
/// callback as its first argument.
///
/// The handWithAsyncMethod function can be used to remove some of the
/// boilerplate from writing wrappers for method calls:
///
/// @code{.cpp}
/// class MyClass {
/// public:
/// void myMethod(move_only_function<void(std::string)> Return,
// uint32_t X, bool Y) { ... }
/// };
///
/// // SPS Method signature -- note MyClass object address as first
/// // argument.
/// using SPSMyMethodWrapperSignature =
/// SPSString(SPSExecutorAddr, uint32_t, bool);
///
///
/// static void adder_add_async_sps_wrapper(
/// orc_rt_SessionRef Session, void *CallCtx,
/// orc_rt_WrapperFunctionReturn Return,
/// orc_rt_WrapperFunctionBuffer ArgBytes) {
/// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool);
/// SPSWrapperFunction<SPSSig>::handle(
/// Session, CallCtx, Return, ArgBytes,
/// WrapperFunction::handleWithAsyncMethod(&MyClass::myMethod));
/// }
/// @endcode
///
template <typename ClassT, typename ReturnT, typename... ArgTs>
static AsyncMethod<ClassT, ReturnT, ArgTs...>
handleWithAsyncMethod(void (ClassT::*M)(ReturnT, ArgTs...)) {
return AsyncMethod<ClassT, ReturnT, ArgTs...>(M);
}

/// Wraps a synchronous method (an ordinary method that returns its result,
/// as opposed to an asynchronous method, see AsyncMethod) for use with
/// WrapperFunction::handle.
///
/// SyncMethod's call operator takes a return callback as its first argument
/// and an ExecutorAddr as its second argument. The ExecutorAddr argument is
/// cast to a ClassT*, and then called passing the subsequent arguments
/// (after the second argument representing the object address). The Return
/// callback is then called on the value returned from the method.
///
/// This utility removes some of the boilerplate from writing wrappers for
/// method calls.
template <typename ClassT, typename RetT, typename... ArgTs>
class SyncMethod {
public:
SyncMethod(RetT (ClassT::*M)(ArgTs...)) : M(M) {}

void operator()(move_only_function<void(RetT)> Return, ExecutorAddr Obj,
ArgTs &&...Args) {
Return((Obj.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...));
}

private:
RetT (ClassT::*M)(ArgTs...);
};

/// Create an SyncMethod wrapper for the given method pointer. The given
/// method should be synchronous, i.e. returning its result (as opposed to
/// asynchronous, see AsyncMethod).
///
/// The handWithAsyncMethod function can be used to remove some of the
/// boilerplate from writing wrappers for method calls:
///
/// @code{.cpp}
/// class MyClass {
/// public:
/// std::string myMethod(uint32_t X, bool Y) { ... }
/// };
///
/// // SPS Method signature -- note MyClass object address as first
/// // argument.
/// using SPSMyMethodWrapperSignature =
/// SPSString(SPSExecutorAddr, uint32_t, bool);
///
///
/// static void adder_add_sync_sps_wrapper(
/// orc_rt_SessionRef Session, void *CallCtx,
/// orc_rt_WrapperFunctionReturn Return,
/// orc_rt_WrapperFunctionBuffer ArgBytes) {
/// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool);
/// SPSWrapperFunction<SPSSig>::handle(
/// Session, CallCtx, Return, ArgBytes,
/// WrapperFunction::handleWithSyncMethod(&Adder::addSync));
/// }
/// @endcode
///
template <typename ClassT, typename RetT, typename... ArgTs>
static SyncMethod<ClassT, RetT, ArgTs...>
handleWithSyncMethod(RetT (ClassT::*M)(ArgTs...)) {
return SyncMethod<ClassT, RetT, ArgTs...>(M);
}

/// Make a call to a wrapper function.
///
/// This utility serializes and deserializes arguments and return values
Expand Down
50 changes: 50 additions & 0 deletions orc-rt/unittests/SPSWrapperFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,53 @@ TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
EXPECT_EQ(OpCounter<3>::moves(), 1U);
EXPECT_EQ(OpCounter<3>::copies(), 0U);
}

namespace {
class Adder {
public:
int32_t addSync(int32_t X, int32_t Y) { return X + Y; }
void addAsync(move_only_function<void(int32_t)> Return, int32_t X,
int32_t Y) {
Return(addSync(X, Y));
}
};
} // anonymous namespace

static void adder_add_async_sps_wrapper(orc_rt_SessionRef Session,
void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(SPSExecutorAddr, int32_t, int32_t)>::handle(
Session, CallCtx, Return, ArgBytes,
WrapperFunction::handleWithAsyncMethod(&Adder::addAsync));
}

TEST(SPSWrapperFunctionUtilsTest, HandleWtihAsyncMethod) {
auto A = std::make_unique<Adder>();
int32_t Result = 0;
SPSWrapperFunction<int32_t(SPSExecutorAddr, int32_t, int32_t)>::call(
DirectCaller(nullptr, adder_add_async_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); },
ExecutorAddr::fromPtr(A.get()), 41, 1);

EXPECT_EQ(Result, 42);
}

static void adder_add_sync_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(SPSExecutorAddr, int32_t, int32_t)>::handle(
Session, CallCtx, Return, ArgBytes,
WrapperFunction::handleWithSyncMethod(&Adder::addSync));
}

TEST(SPSWrapperFunctionUtilsTest, HandleWithSyncMethod) {
auto A = std::make_unique<Adder>();
int32_t Result = 0;
SPSWrapperFunction<int32_t(SPSExecutorAddr, int32_t, int32_t)>::call(
DirectCaller(nullptr, adder_add_sync_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); },
ExecutorAddr::fromPtr(A.get()), 41, 1);

EXPECT_EQ(Result, 42);
}
Loading