Skip to content

Commit e8489c1

Browse files
authored
[orc-rt] WrapperFunction::handle: add by-ref args, minimize temporaries. (#161999)
This adds support for WrapperFunction::handle handlers that take their arguments by reference, rather than by value. This commit also reduces the number of temporary objects created to support SPS-transparent conversion in SPSWrapperFunction.
1 parent 5284c83 commit e8489c1

File tree

3 files changed

+107
-8
lines changed

3 files changed

+107
-8
lines changed

orc-rt/include/orc-rt/SPSWrapperFunction.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ template <typename... SPSArgTs> struct WFSPSHelper {
5757
template <typename... Ts>
5858
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
5959

60-
template <typename T> static T fromSerializable(T &&Arg) noexcept {
61-
return Arg;
60+
template <typename T> static T &&fromSerializable(T &&Arg) noexcept {
61+
return std::forward<T>(Arg);
6262
}
6363

6464
static Error fromSerializable(SPSSerializableError Err) noexcept {
@@ -86,7 +86,10 @@ template <typename... SPSArgTs> struct WFSPSHelper {
8686
decltype(Args)>::deserialize(IB, Args))
8787
return std::nullopt;
8888
return std::apply(
89-
[](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
89+
[](auto &&...A) {
90+
return std::optional<ArgTuple>(std::in_place,
91+
std::move(fromSerializable(A))...);
92+
},
9093
std::move(Args));
9194
}
9295
};

orc-rt/include/orc-rt/WrapperFunction.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl {
111111
static_assert(std::is_void_v<RetT>,
112112
"Async wrapper function handler must return void");
113113
typedef ReturnT YieldType;
114-
typedef std::tuple<ArgTs...> ArgTupleType;
114+
typedef std::tuple<std::decay_t<ArgTs>...> ArgTupleType;
115+
116+
// Forwards arguments based on the parameter types of the handler.
117+
template <typename FnT> class ForwardArgsAsRequested {
118+
public:
119+
ForwardArgsAsRequested(FnT &&Fn) : Fn(std::move(Fn)) {}
120+
void operator()(ArgTs &...Args) { Fn(std::forward<ArgTs>(Args)...); }
121+
122+
private:
123+
FnT Fn;
124+
};
125+
126+
template <typename FnT>
127+
static ForwardArgsAsRequested<std::decay_t<FnT>>
128+
forwardArgsAsRequested(FnT &&Fn) {
129+
return ForwardArgsAsRequested<std::decay_t<FnT>>(std::forward<FnT>(Fn));
130+
}
115131
};
116132

117133
template <typename C>
@@ -244,10 +260,11 @@ struct WrapperFunction {
244260

245261
if (auto Args =
246262
S.arguments().template deserialize<ArgTuple>(std::move(ArgBytes)))
247-
std::apply(bind_front(std::forward<Handler>(H),
248-
detail::StructuredYield<RetTupleType, Serializer>(
249-
Session, CallCtx, Return, std::move(S))),
250-
std::move(*Args));
263+
std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
264+
std::forward<Handler>(H),
265+
detail::StructuredYield<RetTupleType, Serializer>(
266+
Session, CallCtx, Return, std::move(S)))),
267+
*Args);
251268
else
252269
Return(Session, CallCtx,
253270
WrapperFunctionBuffer::createOutOfBandError(

orc-rt/unittests/SPSWrapperFunctionTest.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "CommonTestUtils.h"
14+
1315
#include "orc-rt/SPSWrapperFunction.h"
1416
#include "orc-rt/WrapperFunction.h"
1517
#include "orc-rt/move_only_function.h"
@@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
218220

219221
EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
220222
}
223+
224+
template <size_t N> struct SPSOpCounter {};
225+
226+
namespace orc_rt {
227+
template <size_t N>
228+
class SPSSerializationTraits<SPSOpCounter<N>, OpCounter<N>> {
229+
public:
230+
static size_t size(const OpCounter<N> &O) { return 0; }
231+
static bool serialize(SPSOutputBuffer &OB, const OpCounter<N> &O) {
232+
return true;
233+
}
234+
static bool deserialize(SPSInputBuffer &OB, OpCounter<N> &O) { return true; }
235+
};
236+
} // namespace orc_rt
237+
238+
static void
239+
handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session,
240+
void *CallCtx,
241+
orc_rt_WrapperFunctionReturn Return,
242+
orc_rt_WrapperFunctionBuffer ArgBytes) {
243+
SPSWrapperFunction<void(
244+
SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
245+
SPSOpCounter<3>)>::handle(Session, CallCtx, Return, ArgBytes,
246+
[](move_only_function<void()> Return,
247+
OpCounter<0>, OpCounter<1> &,
248+
const OpCounter<2> &,
249+
OpCounter<3> &&) { Return(); });
250+
}
251+
252+
TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
253+
// Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref
254+
// arguments, and that we generate the expected number of moves.
255+
OpCounter<0>::reset();
256+
OpCounter<1>::reset();
257+
OpCounter<2>::reset();
258+
OpCounter<3>::reset();
259+
260+
bool DidRun = false;
261+
SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
262+
SPSOpCounter<3>)>::
263+
call(
264+
DirectCaller(nullptr, handle_with_reference_types_sps_wrapper),
265+
[&](Error R) {
266+
cantFail(std::move(R));
267+
DidRun = true;
268+
},
269+
OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>());
270+
271+
EXPECT_TRUE(DidRun);
272+
273+
// We expect two default constructions for each parameter: one for the
274+
// argument to call, and one for the object to deserialize into.
275+
EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U);
276+
EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U);
277+
EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U);
278+
EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U);
279+
280+
// Pass-by-value: we expect two moves (one for SPS transparent conversion,
281+
// one to copy the value to the parameter), and no copies.
282+
EXPECT_EQ(OpCounter<0>::moves(), 2U);
283+
EXPECT_EQ(OpCounter<0>::copies(), 0U);
284+
285+
// Pass-by-lvalue-reference: we expect one move (for SPS transparent
286+
// conversion), no copies.
287+
EXPECT_EQ(OpCounter<1>::moves(), 1U);
288+
EXPECT_EQ(OpCounter<1>::copies(), 0U);
289+
290+
// Pass-by-const-lvalue-reference: we expect one move (for SPS transparent
291+
// conversion), no copies.
292+
EXPECT_EQ(OpCounter<2>::moves(), 1U);
293+
EXPECT_EQ(OpCounter<2>::copies(), 0U);
294+
295+
// Pass-by-rvalue-reference: we expect one move (for SPS transparent
296+
// conversion), no copies.
297+
EXPECT_EQ(OpCounter<3>::moves(), 1U);
298+
EXPECT_EQ(OpCounter<3>::copies(), 0U);
299+
}

0 commit comments

Comments
 (0)