|
| 1 | +/** TRACCC library, part of the ACTS project (R&D line) |
| 2 | + * |
| 3 | + * (c) 2024 CERN for the benefit of the ACTS project |
| 4 | + * |
| 5 | + * Mozilla Public License Version 2.0 |
| 6 | + */ |
| 7 | + |
| 8 | +#pragma once |
| 9 | + |
| 10 | +// Project include(s). |
| 11 | +#include "traccc/geometry/detector.hpp" |
| 12 | +#include "traccc/geometry/host_detector.hpp" |
| 13 | +#include "traccc/geometry/move_only_any.hpp" |
| 14 | + |
| 15 | +// Detray include(s). |
| 16 | +#include <any> |
| 17 | +#include <detray/core/detector.hpp> |
| 18 | + |
| 19 | +namespace traccc { |
| 20 | + |
| 21 | +class detector_buffer { |
| 22 | + public: |
| 23 | + template <typename detector_traits_t> |
| 24 | + void set(typename detector_traits_t::buffer&& obj) |
| 25 | + requires(is_detector_traits<detector_traits_t>) |
| 26 | + { |
| 27 | + m_obj.set<typename detector_traits_t::buffer>(std::move(obj)); |
| 28 | + } |
| 29 | + |
| 30 | + template <typename detector_traits_t> |
| 31 | + bool is() const |
| 32 | + requires(is_detector_traits<detector_traits_t>) |
| 33 | + { |
| 34 | + return (type() == typeid(typename detector_traits_t::buffer)); |
| 35 | + } |
| 36 | + |
| 37 | + const std::type_info& type() const { return m_obj.type(); } |
| 38 | + |
| 39 | + template <typename detector_traits_t> |
| 40 | + const typename detector_traits_t::buffer& as() const |
| 41 | + requires(is_detector_traits<detector_traits_t>) |
| 42 | + { |
| 43 | + return m_obj.as<typename detector_traits_t::buffer>(); |
| 44 | + } |
| 45 | + |
| 46 | + template <typename detector_traits_t> |
| 47 | + typename detector_traits_t::view as_view() const |
| 48 | + requires(is_detector_traits<detector_traits_t>) |
| 49 | + { |
| 50 | + return detray::get_data(as<detector_traits_t>()); |
| 51 | + } |
| 52 | + |
| 53 | + private: |
| 54 | + move_only_any m_obj; |
| 55 | +}; // class bfield |
| 56 | + |
| 57 | +/// @brief Helper function for `detector_buffer_visitor` |
| 58 | +template <typename callable_t, typename detector_t, typename... detector_ts> |
| 59 | +auto detector_buffer_visitor_helper(const detector_buffer& detector_buffer, |
| 60 | + callable_t&& callable, |
| 61 | + std::tuple<detector_t, detector_ts...>*) { |
| 62 | + if (detector_buffer.is<detector_t>()) { |
| 63 | + return callable.template operator()<detector_t>( |
| 64 | + detector_buffer.as_view<detector_t>()); |
| 65 | + } else { |
| 66 | + if constexpr (sizeof...(detector_ts) > 0) { |
| 67 | + return detector_buffer_visitor_helper( |
| 68 | + detector_buffer, std::forward<callable_t>(callable), |
| 69 | + static_cast<std::tuple<detector_ts...>*>(nullptr)); |
| 70 | + } else { |
| 71 | + std::stringstream exception_message; |
| 72 | + |
| 73 | + exception_message |
| 74 | + << "Invalid detector type (" << detector_buffer.type().name() |
| 75 | + << ") received, but this type is not supported" << std::endl; |
| 76 | + |
| 77 | + throw std::invalid_argument(exception_message.str()); |
| 78 | + } |
| 79 | + } |
| 80 | +} |
| 81 | + |
| 82 | +/// @brief Visitor for polymorphic detector buffer types |
| 83 | +/// |
| 84 | +/// This function takes a list of supported detector trait types and checks |
| 85 | +/// if the provided field is one of them. If it is, it will call the provided |
| 86 | +/// callable on a view of it and otherwise it will throw an exception. |
| 87 | +template <typename detector_buffer_list_t, typename callable_t> |
| 88 | +auto detector_buffer_visitor(const detector_buffer& detector_buffer, |
| 89 | + callable_t&& callable) { |
| 90 | + return detector_buffer_visitor_helper( |
| 91 | + detector_buffer, std::forward<callable_t>(callable), |
| 92 | + static_cast<detector_buffer_list_t*>(nullptr)); |
| 93 | +} |
| 94 | + |
| 95 | +// TODO: Docs |
| 96 | +inline detector_buffer buffer_from_host_detector(const host_detector& det, |
| 97 | + vecmem::memory_resource& mr, |
| 98 | + vecmem::copy& copy) { |
| 99 | + return host_detector_visitor<traccc::detector_type_list>( |
| 100 | + det, [&mr, ©]<typename detector_traits_t>( |
| 101 | + const typename detector_traits_t::host& detector) { |
| 102 | + traccc::detector_buffer rv; |
| 103 | + rv.set<detector_traits_t>(detray::get_buffer(detector, mr, copy)); |
| 104 | + return rv; |
| 105 | + }); |
| 106 | +} |
| 107 | + |
| 108 | +} // namespace traccc |
0 commit comments