Skip to content

Commit 804f7a2

Browse files
authored
DPL: finally use concepts to separate the make method (AliceO2Group#13611)
1 parent ffc81fd commit 804f7a2

File tree

1 file changed

+104
-85
lines changed

1 file changed

+104
-85
lines changed

Framework/Core/include/Framework/DataAllocator.h

Lines changed: 104 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <TClass.h>
3232
#include <gsl/span>
3333

34+
#include <memory>
3435
#include <vector>
3536
#include <map>
3637
#include <string>
@@ -57,14 +58,6 @@ namespace o2::framework
5758
{
5859
struct ServiceRegistry;
5960

60-
#define ERROR_STRING \
61-
"data type T not supported by API, " \
62-
"\n specializations available for" \
63-
"\n - trivially copyable, non-polymorphic structures" \
64-
"\n - arrays of those" \
65-
"\n - TObject with additional constructor arguments" \
66-
"\n - std containers of those"
67-
6861
/// Helper to allow framework managed objecs to have a callback
6962
/// when they go out of scope. For example, this could
7063
/// be used to serialize a message into a buffer before the
@@ -130,6 +123,10 @@ struct LifetimeHolder {
130123
}
131124
};
132125

126+
template <typename T>
127+
concept VectorOfMessageableTypes = is_specialization_v<T, std::vector> &&
128+
is_messageable<typename T::value_type>::value;
129+
133130
/// This allocator is responsible to make sure that the messages created match
134131
/// the provided spec and that depending on how many pipelined reader we
135132
/// have, messages get created on the channel for the reader of the current
@@ -143,6 +140,7 @@ class DataAllocator
143140
using DataOrigin = o2::header::DataOrigin;
144141
using DataDescription = o2::header::DataDescription;
145142
using SubSpecificationType = o2::header::DataHeader::SubSpecificationType;
143+
146144
template <typename T>
147145
requires std::is_fundamental_v<T>
148146
struct UninitializedVector {
@@ -163,93 +161,114 @@ class DataAllocator
163161
// and with subspecification 0xdeadbeef.
164162
void cookDeadBeef(const Output& spec);
165163

166-
/// Generic helper to create an object which is owned by the framework and
167-
/// returned as a reference to the own object.
168-
/// Note: decltype(auto) will deduce the return type from the expression and it
169-
/// will be lvalue reference for the framework-owned objects. Instances of local
170-
/// variables like shared_ptr will be returned by value/move/return value optimization.
171-
/// Objects created this way will be sent to the channel specified by @spec
172164
template <typename T, typename... Args>
165+
requires is_specialization_v<T, o2::framework::DataAllocator::UninitializedVector>
173166
decltype(auto) make(const Output& spec, Args... args)
174167
{
175168
auto& timingInfo = mRegistry.get<TimingInfo>();
176169
auto& context = mRegistry.get<MessageContext>();
177170

178-
if constexpr (is_specialization_v<T, UninitializedVector>) {
179-
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
180-
// plain buffer as polymorphic spectator std::vector, which does not run constructors / destructors
181-
using ValueType = typename T::value_type;
182-
183-
// Note: initial payload size is 0 and will be set by the context before sending
184-
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
185-
return context.add<MessageContext::VectorObject<ValueType, MessageContext::ContainerRefObject<std::vector<ValueType, o2::pmr::NoConstructAllocator<ValueType>>>>>(
186-
std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...)
187-
.get();
188-
} else if constexpr (is_specialization_v<T, std::vector> && has_messageable_value_type<T>::value) {
189-
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
190-
// this catches all std::vector objects with messageable value type before checking if is also
191-
// has a root dictionary, so non-serialized transmission is preferred
192-
using ValueType = typename T::value_type;
193-
194-
// Note: initial payload size is 0 and will be set by the context before sending
195-
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
196-
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...).get();
197-
} else if constexpr (has_root_dictionary<T>::value == true && is_messageable<T>::value == false) {
198-
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
199-
// Extended support for types implementing the Root ClassDef interface, both TObject
200-
// derived types and others
201-
if constexpr (enable_root_serialization<T>::value) {
202-
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodROOT, 0);
203-
204-
return context.add<typename enable_root_serialization<T>::object_type>(std::move(headerMessage), routeIndex, std::forward<Args>(args)...).get();
205-
} else {
206-
static_assert(enable_root_serialization<T>::value, "Please make sure you include RootMessageContext.h");
207-
}
208-
// Note: initial payload size is 0 and will be set by the context before sending
209-
} else if constexpr (std::is_base_of_v<std::string, T>) {
210-
auto* s = new std::string(args...);
211-
adopt(spec, s);
212-
return *s;
213-
} else if constexpr (requires { static_cast<struct TableBuilder>(std::declval<std::decay_t<T>>()); }) {
214-
auto tb = std::move(LifetimeHolder<TableBuilder>(new std::decay_t<T>(args...)));
215-
adopt(spec, tb);
216-
return tb;
217-
} else if constexpr (requires { static_cast<struct TreeToTable>(std::declval<std::decay_t<T>>()); }) {
218-
auto t2t = std::move(LifetimeHolder<TreeToTable>(new std::decay_t<T>(args...)));
219-
adopt(spec, t2t);
220-
return t2t;
221-
} else if constexpr (sizeof...(Args) == 0) {
222-
if constexpr (is_messageable<T>::value == true) {
223-
return *reinterpret_cast<T*>(newChunk(spec, sizeof(T)).data());
224-
} else {
225-
static_assert(always_static_assert_v<T>, ERROR_STRING);
226-
}
227-
} else if constexpr (sizeof...(Args) == 1) {
228-
using FirstArg = typename std::tuple_element<0, std::tuple<Args...>>::type;
229-
if constexpr (std::is_integral_v<FirstArg>) {
230-
if constexpr (is_messageable<T>::value == true) {
231-
auto [nElements] = std::make_tuple(args...);
232-
auto size = nElements * sizeof(T);
233-
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
234-
235-
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, size);
236-
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), routeIndex, 0, nElements).get();
237-
}
238-
} else if constexpr (std::is_same_v<FirstArg, std::shared_ptr<arrow::Schema>>) {
239-
if constexpr (std::is_base_of_v<arrow::ipc::RecordBatchWriter, T>) {
240-
auto [schema] = std::make_tuple(args...);
241-
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
242-
create(spec, &writer, schema);
243-
return writer;
244-
}
245-
} else {
246-
static_assert(always_static_assert_v<T>, ERROR_STRING);
247-
}
171+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
172+
// plain buffer as polymorphic spectator std::vector, which does not run constructors / destructors
173+
using ValueType = typename T::value_type;
174+
175+
// Note: initial payload size is 0 and will be set by the context before sending
176+
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
177+
return context.add<MessageContext::VectorObject<ValueType, MessageContext::ContainerRefObject<std::vector<ValueType, o2::pmr::NoConstructAllocator<ValueType>>>>>(
178+
std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...)
179+
.get();
180+
}
181+
182+
template <typename T, typename... Args>
183+
requires VectorOfMessageableTypes<T>
184+
decltype(auto) make(const Output& spec, Args... args)
185+
{
186+
auto& timingInfo = mRegistry.get<TimingInfo>();
187+
auto& context = mRegistry.get<MessageContext>();
188+
189+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
190+
// this catches all std::vector objects with messageable value type before checking if is also
191+
// has a root dictionary, so non-serialized transmission is preferred
192+
using ValueType = typename T::value_type;
193+
194+
// Note: initial payload size is 0 and will be set by the context before sending
195+
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
196+
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...).get();
197+
}
198+
199+
template <typename T, typename... Args>
200+
requires(!VectorOfMessageableTypes<T> && has_root_dictionary<T>::value == true && is_messageable<T>::value == false)
201+
decltype(auto) make(const Output& spec, Args... args)
202+
{
203+
auto& timingInfo = mRegistry.get<TimingInfo>();
204+
auto& context = mRegistry.get<MessageContext>();
205+
206+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
207+
// Extended support for types implementing the Root ClassDef interface, both TObject
208+
// derived types and others
209+
if constexpr (enable_root_serialization<T>::value) {
210+
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodROOT, 0);
211+
212+
return context.add<typename enable_root_serialization<T>::object_type>(std::move(headerMessage), routeIndex, std::forward<Args>(args)...).get();
248213
} else {
249-
static_assert(always_static_assert_v<T>, ERROR_STRING);
214+
static_assert(enable_root_serialization<T>::value, "Please make sure you include RootMessageContext.h");
250215
}
251216
}
252217

218+
template <typename T, typename... Args>
219+
requires std::is_base_of_v<std::string, T>
220+
decltype(auto) make(const Output& spec, Args... args)
221+
{
222+
auto* s = new std::string(args...);
223+
adopt(spec, s);
224+
return *s;
225+
}
226+
227+
template <typename T, typename... Args>
228+
requires(requires { static_cast<struct TableBuilder>(std::declval<std::decay_t<T>>()); })
229+
decltype(auto) make(const Output& spec, Args... args)
230+
{
231+
auto tb = std::move(LifetimeHolder<TableBuilder>(new std::decay_t<T>(args...)));
232+
adopt(spec, tb);
233+
return tb;
234+
}
235+
236+
template <typename T, typename... Args>
237+
requires(requires { static_cast<struct TreeToTable>(std::declval<std::decay_t<T>>()); })
238+
decltype(auto) make(const Output& spec, Args... args)
239+
{
240+
auto t2t = std::move(LifetimeHolder<TreeToTable>(new std::decay_t<T>(args...)));
241+
adopt(spec, t2t);
242+
return t2t;
243+
}
244+
245+
template <typename T>
246+
requires is_messageable<T>::value && (!is_specialization_v<T, UninitializedVector>)
247+
decltype(auto) make(const Output& spec)
248+
{
249+
return *reinterpret_cast<T*>(newChunk(spec, sizeof(T)).data());
250+
}
251+
252+
template <typename T>
253+
requires is_messageable<T>::value && (!is_specialization_v<T, UninitializedVector>)
254+
decltype(auto) make(const Output& spec, std::integral auto nElements)
255+
{
256+
auto& timingInfo = mRegistry.get<TimingInfo>();
257+
auto& context = mRegistry.get<MessageContext>();
258+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
259+
260+
fair::mq::MessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, nElements * sizeof(T));
261+
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), routeIndex, 0, nElements).get();
262+
}
263+
264+
template <typename T, typename Arg>
265+
decltype(auto) make(const Output& spec, std::same_as<std::shared_ptr<arrow::Schema>> auto schema)
266+
{
267+
std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
268+
create(spec, &writer, schema);
269+
return writer;
270+
}
271+
253272
/// Adopt a string in the framework and serialize / send
254273
/// it to the consumers of @a spec once done.
255274
void

0 commit comments

Comments
 (0)