diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h index 47e770f0bfbf7..c43f1c5b4a753 100644 --- a/orc-rt/include/orc-rt/WrapperFunction.h +++ b/orc-rt/include/orc-rt/WrapperFunction.h @@ -16,7 +16,9 @@ #include "orc-rt-c/WrapperFunction.h" #include "orc-rt/CallableTraitsHelper.h" #include "orc-rt/Error.h" +#include "orc-rt/ExecutorAddress.h" #include "orc-rt/bind.h" +#include "orc-rt/move_only_function.h" #include @@ -205,6 +207,128 @@ struct ResultDeserializer, Serializer> { /// wrapper functions in C++. struct WrapperFunction { + /// Wraps an asynchronous method (a method returning void, and taking a + /// return callback as its first argument) for use with + /// WrapperFunction::handle. + /// + /// AsyncMethod's call operator takes an ExecutorAddr as its second argument, + /// casts it to a ClassT*, and then calls the wrapped method on that pointer, + /// forwarding the return callback and any subsequent arguments (after the + /// second argument representing the object address). + /// + /// This utility removes some of the boilerplate from writing wrappers for + /// method calls. + template + struct AsyncMethod { + AsyncMethod(void (ClassT::*M)(ReturnT, ArgTs...)) : M(M) {} + void operator()(ReturnT &&Return, ExecutorAddr Obj, ArgTs &&...Args) { + (Obj.toPtr()->*M)(std::forward(Return), + std::forward(Args)...); + } + + private: + void (ClassT::*M)(ReturnT, ArgTs...); + }; + + /// Create an AsyncMethod wrapper for the given method pointer. The given + /// method should be asynchronous: returning void, and taking a return + /// callback as its first argument. + /// + /// The handWithAsyncMethod function can be used to remove some of the + /// boilerplate from writing wrappers for method calls: + /// + /// @code{.cpp} + /// class MyClass { + /// public: + /// void myMethod(move_only_function Return, + // uint32_t X, bool Y) { ... } + /// }; + /// + /// // SPS Method signature -- note MyClass object address as first + /// // argument. + /// using SPSMyMethodWrapperSignature = + /// SPSString(SPSExecutorAddr, uint32_t, bool); + /// + /// + /// static void adder_add_async_sps_wrapper( + /// orc_rt_SessionRef Session, void *CallCtx, + /// orc_rt_WrapperFunctionReturn Return, + /// orc_rt_WrapperFunctionBuffer ArgBytes) { + /// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool); + /// SPSWrapperFunction::handle( + /// Session, CallCtx, Return, ArgBytes, + /// WrapperFunction::handleWithAsyncMethod(&MyClass::myMethod)); + /// } + /// @endcode + /// + template + static AsyncMethod + handleWithAsyncMethod(void (ClassT::*M)(ReturnT, ArgTs...)) { + return AsyncMethod(M); + } + + /// Wraps a synchronous method (an ordinary method that returns its result, + /// as opposed to an asynchronous method, see AsyncMethod) for use with + /// WrapperFunction::handle. + /// + /// SyncMethod's call operator takes a return callback as its first argument + /// and an ExecutorAddr as its second argument. The ExecutorAddr argument is + /// cast to a ClassT*, and then called passing the subsequent arguments + /// (after the second argument representing the object address). The Return + /// callback is then called on the value returned from the method. + /// + /// This utility removes some of the boilerplate from writing wrappers for + /// method calls. + template + class SyncMethod { + public: + SyncMethod(RetT (ClassT::*M)(ArgTs...)) : M(M) {} + + void operator()(move_only_function Return, ExecutorAddr Obj, + ArgTs &&...Args) { + Return((Obj.toPtr()->*M)(std::forward(Args)...)); + } + + private: + RetT (ClassT::*M)(ArgTs...); + }; + + /// Create an SyncMethod wrapper for the given method pointer. The given + /// method should be synchronous, i.e. returning its result (as opposed to + /// asynchronous, see AsyncMethod). + /// + /// The handWithAsyncMethod function can be used to remove some of the + /// boilerplate from writing wrappers for method calls: + /// + /// @code{.cpp} + /// class MyClass { + /// public: + /// std::string myMethod(uint32_t X, bool Y) { ... } + /// }; + /// + /// // SPS Method signature -- note MyClass object address as first + /// // argument. + /// using SPSMyMethodWrapperSignature = + /// SPSString(SPSExecutorAddr, uint32_t, bool); + /// + /// + /// static void adder_add_sync_sps_wrapper( + /// orc_rt_SessionRef Session, void *CallCtx, + /// orc_rt_WrapperFunctionReturn Return, + /// orc_rt_WrapperFunctionBuffer ArgBytes) { + /// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool); + /// SPSWrapperFunction::handle( + /// Session, CallCtx, Return, ArgBytes, + /// WrapperFunction::handleWithSyncMethod(&Adder::addSync)); + /// } + /// @endcode + /// + template + static SyncMethod + handleWithSyncMethod(RetT (ClassT::*M)(ArgTs...)) { + return SyncMethod(M); + } + /// Make a call to a wrapper function. /// /// This utility serializes and deserializes arguments and return values diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp index 32aaa61639dbb..e010e2a067adf 100644 --- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -297,3 +297,53 @@ TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) { EXPECT_EQ(OpCounter<3>::moves(), 1U); EXPECT_EQ(OpCounter<3>::copies(), 0U); } + +namespace { +class Adder { +public: + int32_t addSync(int32_t X, int32_t Y) { return X + Y; } + void addAsync(move_only_function Return, int32_t X, + int32_t Y) { + Return(addSync(X, Y)); + } +}; +} // anonymous namespace + +static void adder_add_async_sps_wrapper(orc_rt_SessionRef Session, + void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + Session, CallCtx, Return, ArgBytes, + WrapperFunction::handleWithAsyncMethod(&Adder::addAsync)); +} + +TEST(SPSWrapperFunctionUtilsTest, HandleWtihAsyncMethod) { + auto A = std::make_unique(); + int32_t Result = 0; + SPSWrapperFunction::call( + DirectCaller(nullptr, adder_add_async_sps_wrapper), + [&](Expected R) { Result = cantFail(std::move(R)); }, + ExecutorAddr::fromPtr(A.get()), 41, 1); + + EXPECT_EQ(Result, 42); +} + +static void adder_add_sync_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + Session, CallCtx, Return, ArgBytes, + WrapperFunction::handleWithSyncMethod(&Adder::addSync)); +} + +TEST(SPSWrapperFunctionUtilsTest, HandleWithSyncMethod) { + auto A = std::make_unique(); + int32_t Result = 0; + SPSWrapperFunction::call( + DirectCaller(nullptr, adder_add_sync_sps_wrapper), + [&](Expected R) { Result = cantFail(std::move(R)); }, + ExecutorAddr::fromPtr(A.get()), 41, 1); + + EXPECT_EQ(Result, 42); +}