Skip to content

Commit 1241c09

Browse files
authored
Merge pull request #1068 from stephenswat/refactor/polymorphic_detectors
Implement polymorphic detector storage types
2 parents 1852269 + 98d3cd0 commit 1241c09

File tree

120 files changed

+1444
-1382
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

120 files changed

+1444
-1382
lines changed

benchmarks/cpu/toy_detector_cpu.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,8 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
3838
// VecMem copy object
3939
vecmem::copy copy;
4040

41-
// Type declarations
42-
using host_detector_type = traccc::default_detector::host;
43-
4441
// Read back detector file
45-
host_detector_type det{host_mr};
42+
traccc::host_detector det;
4643
traccc::io::read_detector(
4744
det, host_mr, sim_dir + "toy_detector_geometry.json",
4845
sim_dir + "toy_detector_homogeneous_material.json",

benchmarks/cuda/toy_detector_cuda.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
4949
vecmem::cuda::async_copy async_copy{stream.cudaStream()};
5050

5151
// Read back detector file
52-
traccc::default_detector::host det{cuda_host_mr};
52+
traccc::host_detector det;
5353
traccc::io::read_detector(
5454
det, cuda_host_mr, sim_dir + "toy_detector_geometry.json",
5555
sim_dir + "toy_detector_homogeneous_material.json",
@@ -68,9 +68,8 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
6868
async_copy, stream);
6969

7070
// Copy detector to device
71-
const auto det_buffer = detray::get_buffer(det, device_mr, copy);
72-
// Detector view object
73-
auto det_view = detray::get_data(det_buffer);
71+
const traccc::detector_buffer det_buffer =
72+
traccc::buffer_from_host_detector(det, device_mr, copy);
7473

7574
for (auto _ : state) {
7675

@@ -120,13 +119,13 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
120119
// Run CKF track finding
121120
traccc::edm::track_candidate_collection<
122121
traccc::default_algebra>::buffer track_candidates_cuda_buffer =
123-
device_finding(det_view, field, measurements_cuda_buffer,
122+
device_finding(det_buffer, field, measurements_cuda_buffer,
124123
params_cuda_buffer);
125124

126125
// Run track fitting
127126
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
128127
track_states_cuda_buffer = device_fitting(
129-
det_view, field,
128+
det_buffer, field,
130129
{track_candidates_cuda_buffer, measurements_cuda_buffer});
131130

132131
// Create a temporary buffer that will receive the device memory.

core/CMakeLists.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ traccc_add_library( traccc_core core TYPE SHARED
8686
"include/traccc/finding/details/combinatorial_kalman_filter.hpp"
8787
"include/traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
8888
"src/finding/combinatorial_kalman_filter_algorithm.cpp"
89-
"src/finding/combinatorial_kalman_filter_algorithm_default_detector.cpp"
90-
"src/finding/combinatorial_kalman_filter_algorithm_telescope_detector.cpp"
9189
# Fitting algorithmic code
9290
"include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
9391
"include/traccc/fitting/kalman_filter/kalman_actor.hpp"
@@ -99,8 +97,6 @@ traccc_add_library( traccc_core core TYPE SHARED
9997
"include/traccc/fitting/details/kalman_fitting.hpp"
10098
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
10199
"src/fitting/kalman_fitting_algorithm.cpp"
102-
"src/fitting/kalman_fitting_algorithm_default_detector.cpp"
103-
"src/fitting/kalman_fitting_algorithm_telescope_detector.cpp"
104100
# Seed finding algorithmic code.
105101
"include/traccc/seeding/detail/lin_circle.hpp"
106102
"include/traccc/seeding/detail/doublet.hpp"
@@ -130,8 +126,6 @@ traccc_add_library( traccc_core core TYPE SHARED
130126
"src/seeding/silicon_pixel_spacepoint_formation.hpp"
131127
"include/traccc/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
132128
"src/seeding/silicon_pixel_spacepoint_formation_algorithm.cpp"
133-
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_defdet.cpp"
134-
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_teldet.cpp"
135129
# Ambiguity resolution
136130
"include/traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
137131
"include/traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"

core/include/traccc/finding/combinatorial_kalman_filter_algorithm.hpp

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "traccc/edm/track_parameters.hpp"
1515
#include "traccc/finding/finding_config.hpp"
1616
#include "traccc/geometry/detector.hpp"
17+
#include "traccc/geometry/host_detector.hpp"
1718
#include "traccc/utils/algorithm.hpp"
1819
#include "traccc/utils/messaging.hpp"
1920

@@ -32,11 +33,7 @@ namespace traccc::host {
3233
///
3334
class combinatorial_kalman_filter_algorithm
3435
: public algorithm<edm::track_candidate_collection<default_algebra>::host(
35-
const default_detector::host&, const magnetic_field&,
36-
const measurement_collection_types::const_view&,
37-
const bound_track_parameters_collection_types::const_view&)>,
38-
public algorithm<edm::track_candidate_collection<default_algebra>::host(
39-
const telescope_detector::host&, const magnetic_field&,
36+
const host_detector&, const magnetic_field&,
4037
const measurement_collection_types::const_view&,
4138
const bound_track_parameters_collection_types::const_view&)>,
4239
public messaging {
@@ -54,23 +51,7 @@ class combinatorial_kalman_filter_algorithm
5451

5552
/// Execute the algorithm
5653
///
57-
/// @param det The (default) detector object
58-
/// @param bfield The magnetic field object
59-
/// @param measurements All measurements in an event
60-
/// @param seeds All seeds in an event to start the track finding
61-
/// with
62-
///
63-
/// @return A container of the found track candidates
64-
///
65-
output_type operator()(
66-
const default_detector::host& det, const magnetic_field& bfield,
67-
const measurement_collection_types::const_view& measurements,
68-
const bound_track_parameters_collection_types::const_view& seeds)
69-
const override;
70-
71-
/// Execute the algorithm
72-
///
73-
/// @param det The (telescope) detector object
54+
/// @param det The detector object
7455
/// @param bfield The magnetic field object
7556
/// @param measurements All measurements in an event
7657
/// @param seeds All seeds in an event to start the track finding
@@ -79,7 +60,7 @@ class combinatorial_kalman_filter_algorithm
7960
/// @return A container of the found track candidates
8061
///
8162
output_type operator()(
82-
const telescope_detector::host& det, const magnetic_field& bfield,
63+
const host_detector& det, const magnetic_field& bfield,
8364
const measurement_collection_types::const_view& measurements,
8465
const bound_track_parameters_collection_types::const_view& seeds)
8566
const override;

core/include/traccc/fitting/kalman_fitting_algorithm.hpp

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "traccc/edm/track_fit_container.hpp"
1515
#include "traccc/fitting/fitting_config.hpp"
1616
#include "traccc/geometry/detector.hpp"
17+
#include "traccc/geometry/host_detector.hpp"
1718
#include "traccc/utils/algorithm.hpp"
1819
#include "traccc/utils/messaging.hpp"
1920

@@ -32,10 +33,7 @@ namespace traccc::host {
3233
/// Kalman filter based track fitting algorithm
3334
class kalman_fitting_algorithm
3435
: public algorithm<edm::track_fit_container<default_algebra>::host(
35-
const default_detector::host&, const magnetic_field&,
36-
const edm::track_candidate_container<default_algebra>::const_view&)>,
37-
public algorithm<edm::track_fit_container<default_algebra>::host(
38-
const telescope_detector::host&, const magnetic_field&,
36+
const host_detector&, const magnetic_field&,
3937
const edm::track_candidate_container<default_algebra>::const_view&)>,
4038
public messaging {
4139

@@ -56,27 +54,14 @@ class kalman_fitting_algorithm
5654

5755
/// Execute the algorithm
5856
///
59-
/// @param det The (default) detector object
60-
/// @param bfield The magnetic field object
61-
/// @param track_candidates All track candidates to fit
62-
///
63-
/// @return A container of the fitted track states
64-
///
65-
output_type operator()(
66-
const default_detector::host& det, const magnetic_field& bfield,
67-
const edm::track_candidate_container<default_algebra>::const_view&
68-
track_candidates) const override;
69-
70-
/// Execute the algorithm
71-
///
72-
/// @param det The (telescope) detector object
57+
/// @param det The detector object
7358
/// @param bfield The magnetic field object
7459
/// @param track_candidates All track candidates to fit
7560
///
7661
/// @return A container of the fitted track states
7762
///
7863
output_type operator()(
79-
const telescope_detector::host& det, const magnetic_field& bfield,
64+
const host_detector& det, const magnetic_field& bfield,
8065
const edm::track_candidate_container<default_algebra>::const_view&
8166
track_candidates) const override;
8267

core/include/traccc/geometry/detector.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct device_detector_container_types {
4747

4848
/// Base struct for the different detector types supported by the project.
4949
template <typename metadata_t>
50-
struct detector {
50+
struct detector_traits {
5151

5252
/// Metadata type of the detector.
5353
using metadata_type = metadata_t;
@@ -66,15 +66,26 @@ struct detector {
6666

6767
}; // struct default_detector
6868

69+
template <typename T>
70+
concept is_detector_traits = requires {
71+
typename T::metadata_type;
72+
typename T::host;
73+
typename T::device;
74+
typename T::view;
75+
typename T::buffer;
76+
};
77+
6978
/// Default detector (also used for ODD)
7079
using default_detector =
71-
detector<detray::default_metadata<traccc::default_algebra>>;
80+
detector_traits<detray::default_metadata<traccc::default_algebra>>;
7281

7382
/// Telescope detector
74-
using telescope_detector = detector<
83+
using telescope_detector = detector_traits<
7584
detray::telescope_metadata<traccc::default_algebra, detray::rectangle2D>>;
7685

7786
/// Toy detector
78-
using toy_detector = detector<detray::toy_metadata<traccc::default_algebra>>;
87+
using toy_detector =
88+
detector_traits<detray::toy_metadata<traccc::default_algebra>>;
7989

90+
using detector_type_list = std::tuple<default_detector, telescope_detector>;
8091
} // namespace traccc
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/geometry/detector.hpp"
12+
#include "traccc/geometry/host_detector.hpp"
13+
#include "traccc/geometry/move_only_any.hpp"
14+
15+
// Detray include(s).
16+
#include <any>
17+
#include <detray/core/detector.hpp>
18+
19+
namespace traccc {
20+
21+
class detector_buffer {
22+
public:
23+
template <typename detector_traits_t>
24+
void set(typename detector_traits_t::buffer&& obj)
25+
requires(is_detector_traits<detector_traits_t>)
26+
{
27+
m_obj.set<typename detector_traits_t::buffer>(std::move(obj));
28+
}
29+
30+
template <typename detector_traits_t>
31+
bool is() const
32+
requires(is_detector_traits<detector_traits_t>)
33+
{
34+
return (type() == typeid(typename detector_traits_t::buffer));
35+
}
36+
37+
const std::type_info& type() const { return m_obj.type(); }
38+
39+
template <typename detector_traits_t>
40+
const typename detector_traits_t::buffer& as() const
41+
requires(is_detector_traits<detector_traits_t>)
42+
{
43+
return m_obj.as<typename detector_traits_t::buffer>();
44+
}
45+
46+
template <typename detector_traits_t>
47+
typename detector_traits_t::view as_view() const
48+
requires(is_detector_traits<detector_traits_t>)
49+
{
50+
return detray::get_data(as<detector_traits_t>());
51+
}
52+
53+
private:
54+
move_only_any m_obj;
55+
}; // class bfield
56+
57+
/// @brief Helper function for `detector_buffer_visitor`
58+
template <typename callable_t, typename detector_t, typename... detector_ts>
59+
auto detector_buffer_visitor_helper(const detector_buffer& detector_buffer,
60+
callable_t&& callable,
61+
std::tuple<detector_t, detector_ts...>*) {
62+
if (detector_buffer.is<detector_t>()) {
63+
return callable.template operator()<detector_t>(
64+
detector_buffer.as_view<detector_t>());
65+
} else {
66+
if constexpr (sizeof...(detector_ts) > 0) {
67+
return detector_buffer_visitor_helper(
68+
detector_buffer, std::forward<callable_t>(callable),
69+
static_cast<std::tuple<detector_ts...>*>(nullptr));
70+
} else {
71+
std::stringstream exception_message;
72+
73+
exception_message
74+
<< "Invalid detector type (" << detector_buffer.type().name()
75+
<< ") received, but this type is not supported" << std::endl;
76+
77+
throw std::invalid_argument(exception_message.str());
78+
}
79+
}
80+
}
81+
82+
/// @brief Visitor for polymorphic detector buffer types
83+
///
84+
/// This function takes a list of supported detector trait types and checks
85+
/// if the provided field is one of them. If it is, it will call the provided
86+
/// callable on a view of it and otherwise it will throw an exception.
87+
template <typename detector_buffer_list_t, typename callable_t>
88+
auto detector_buffer_visitor(const detector_buffer& detector_buffer,
89+
callable_t&& callable) {
90+
return detector_buffer_visitor_helper(
91+
detector_buffer, std::forward<callable_t>(callable),
92+
static_cast<detector_buffer_list_t*>(nullptr));
93+
}
94+
95+
// TODO: Docs
96+
inline detector_buffer buffer_from_host_detector(const host_detector& det,
97+
vecmem::memory_resource& mr,
98+
vecmem::copy& copy) {
99+
return host_detector_visitor<traccc::detector_type_list>(
100+
det, [&mr, &copy]<typename detector_traits_t>(
101+
const typename detector_traits_t::host& detector) {
102+
traccc::detector_buffer rv;
103+
rv.set<detector_traits_t>(detray::get_buffer(detector, mr, copy));
104+
return rv;
105+
});
106+
}
107+
108+
} // namespace traccc

0 commit comments

Comments
 (0)