Skip to content

Commit 811c31e

Browse files
lhamesMixedMatched
authored andcommitted
[orc-rt] Add transparent SPS conversion for error/expected types. (llvm#161768)
This commit aims to reduce boilerplate by adding transparent conversion between Error/Expected types and their SPS-serializable counterparts (SPSSerializableError/SPSSerializableExpected). This allows SPSWrapperFunction calls and handles to be written in terms of Error/Expected directly. This functionality can also be extended to transparently convert between other types. This may be used in the future to provide conversion between ExecutorAddr and native pointer types.
1 parent fbb8c01 commit 811c31e

File tree

3 files changed

+129
-7
lines changed

3 files changed

+129
-7
lines changed

orc-rt/include/orc-rt/SPSWrapperFunction.h

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ namespace orc_rt {
2121
namespace detail {
2222

2323
template <typename... SPSArgTs> struct WFSPSHelper {
24-
template <typename... ArgTs>
25-
std::optional<WrapperFunctionBuffer> serialize(const ArgTs &...Args) {
24+
private:
25+
template <typename... SerializableArgTs>
26+
std::optional<WrapperFunctionBuffer>
27+
serializeImpl(const SerializableArgTs &...Args) {
2628
auto R =
2729
WrapperFunctionBuffer::allocate(SPSArgList<SPSArgTs...>::size(Args...));
2830
SPSOutputBuffer OB(R.data(), R.size());
@@ -31,16 +33,61 @@ template <typename... SPSArgTs> struct WFSPSHelper {
3133
return std::move(R);
3234
}
3335

36+
template <typename T> static const T &toSerializable(const T &Arg) noexcept {
37+
return Arg;
38+
}
39+
40+
static SPSSerializableError toSerializable(Error Err) noexcept {
41+
return SPSSerializableError(std::move(Err));
42+
}
43+
44+
template <typename T>
45+
static SPSSerializableExpected<T> toSerializable(Expected<T> Arg) noexcept {
46+
return SPSSerializableExpected<T>(std::move(Arg));
47+
}
48+
49+
template <typename... Ts> struct DeserializableTuple;
50+
51+
template <typename... Ts> struct DeserializableTuple<std::tuple<Ts...>> {
52+
typedef std::tuple<
53+
std::decay_t<decltype(toSerializable(std::declval<Ts>()))>...>
54+
type;
55+
};
56+
57+
template <typename... Ts>
58+
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
59+
60+
template <typename T> static T fromSerializable(T &&Arg) noexcept {
61+
return Arg;
62+
}
63+
64+
static Error fromSerializable(SPSSerializableError Err) noexcept {
65+
return Err.toError();
66+
}
67+
68+
template <typename T>
69+
static Expected<T> fromSerializable(SPSSerializableExpected<T> Val) noexcept {
70+
return Val.toExpected();
71+
}
72+
73+
public:
74+
template <typename... ArgTs>
75+
std::optional<WrapperFunctionBuffer> serialize(ArgTs &&...Args) {
76+
return serializeImpl(toSerializable(std::forward<ArgTs>(Args))...);
77+
}
78+
3479
template <typename ArgTuple>
3580
std::optional<ArgTuple> deserialize(WrapperFunctionBuffer ArgBytes) {
3681
assert(!ArgBytes.getOutOfBandError() &&
3782
"Should not attempt to deserialize out-of-band error");
3883
SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size());
39-
ArgTuple Args;
40-
if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>, ArgTuple>::deserialize(
41-
IB, Args))
84+
DeserializableTuple_t<ArgTuple> Args;
85+
if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>,
86+
decltype(Args)>::deserialize(IB, Args))
4287
return std::nullopt;
43-
return Args;
88+
return std::apply(
89+
[](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
90+
std::move(Args));
4491
}
4592
};
4693

orc-rt/include/orc-rt/WrapperFunction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
168168
Serializer &S) {
169169
if (auto Val = S.result().template deserialize<std::tuple<T>>(
170170
std::move(ResultBytes)))
171-
return std::move(std::get<0>(*Val));
171+
return Expected<T>(std::move(std::get<0>(*Val)),
172+
ForceExpectedSuccessValue());
172173
else
173174
return make_error<StringError>("Could not deserialize result");
174175
}

orc-rt/unittests/SPSWrapperFunctionTest.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,77 @@ TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
144144
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
145145
EXPECT_EQ(Result, 42);
146146
}
147+
148+
static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session,
149+
void *CallCtx,
150+
orc_rt_WrapperFunctionReturn Return,
151+
orc_rt_WrapperFunctionBuffer ArgBytes) {
152+
SPSWrapperFunction<SPSError(bool)>::handle(
153+
Session, CallCtx, Return, ArgBytes,
154+
[](move_only_function<void(Error)> Return, bool LuckyHat) {
155+
if (LuckyHat)
156+
Return(Error::success());
157+
else
158+
Return(make_error<StringError>("crushed by boulder"));
159+
});
160+
}
161+
162+
TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) {
163+
bool DidRun = false;
164+
SPSWrapperFunction<SPSError(bool)>::call(
165+
DirectCaller(nullptr, improbable_feat_sps_wrapper),
166+
[&](Expected<Error> E) {
167+
DidRun = true;
168+
cantFail(cantFail(std::move(E)));
169+
},
170+
true);
171+
172+
EXPECT_TRUE(DidRun);
173+
}
174+
175+
TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) {
176+
std::string ErrMsg;
177+
SPSWrapperFunction<SPSError(bool)>::call(
178+
DirectCaller(nullptr, improbable_feat_sps_wrapper),
179+
[&](Expected<Error> E) { ErrMsg = toString(cantFail(std::move(E))); },
180+
false);
181+
182+
EXPECT_EQ(ErrMsg, "crushed by boulder");
183+
}
184+
185+
static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
186+
orc_rt_WrapperFunctionReturn Return,
187+
orc_rt_WrapperFunctionBuffer ArgBytes) {
188+
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::handle(
189+
Session, CallCtx, Return, ArgBytes,
190+
[](move_only_function<void(Expected<int32_t>)> Return, int N) {
191+
if (N % 2 == 0)
192+
Return(N >> 1);
193+
else
194+
Return(make_error<StringError>("N is not a multiple of 2"));
195+
});
196+
}
197+
198+
TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) {
199+
int32_t Result = 0;
200+
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
201+
DirectCaller(nullptr, halve_number_sps_wrapper),
202+
[&](Expected<Expected<int32_t>> R) {
203+
Result = cantFail(cantFail(std::move(R)));
204+
},
205+
2);
206+
207+
EXPECT_EQ(Result, 1);
208+
}
209+
210+
TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
211+
std::string ErrMsg;
212+
SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
213+
DirectCaller(nullptr, halve_number_sps_wrapper),
214+
[&](Expected<Expected<int32_t>> R) {
215+
ErrMsg = toString(cantFail(std::move(R)).takeError());
216+
},
217+
3);
218+
219+
EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
220+
}

0 commit comments

Comments
 (0)