Skip to content

Commit d6d6045

Browse files
authored
Track Fit/State SoA, main branch (2025.08.05.) (#1112)
* Introduced an SoA EDM for the track fit in traccc::core. * Adapted traccc::device_common and traccc::cuda to the SoA track fit EDM. * Adapted traccc::alpaka and traccc::sycl to the SoA track fit EDM. * Adapted the tests and benchmarks to the SoA track fit EDM. * Adapted the examples to the SoA track fit EDM. * Introduced track_candidates_container::device and track_fit_container::device. To make it a bit easier to deal with the slightly complex EDMs of track finding and fitting in the code. * Moved "complex EDM operations" to helper functions. * Adjusted the host Kalman fit testing code a little. Taking into account that the updated code now returns tracks that have not been successfully fitted. Previously it would not return the tracks for which the fit failed. * Reintroduce print_fitted_tracks_statistics(...). But now as a function in the examples code. Not as part of traccc::core. * Small changes in the host fitting code. Based on the PR feedback. * Made the plotting tool documentation more explicit.
1 parent fcd75b7 commit d6d6045

File tree

94 files changed

+2048
-1072
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

94 files changed

+2048
-1072
lines changed

benchmarks/cuda/toy_detector_cuda.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
7373
// Detector view object
7474
auto det_view = detray::get_data(det_buffer);
7575

76-
// D2H copy object
77-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
78-
track_state_d2h{{device_mr, &host_mr}, copy};
79-
8076
for (auto _ : state) {
8177

8278
state.PauseTiming();
@@ -129,7 +125,7 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
129125
params_cuda_buffer);
130126

131127
// Run track fitting
132-
traccc::track_state_container_types::buffer
128+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
133129
track_states_cuda_buffer = device_fitting(
134130
det_view, field,
135131
{track_candidates_cuda_buffer, measurements_cuda_buffer});

core/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ traccc_add_library( traccc_core core TYPE SHARED
2121
"include/traccc/edm/details/device_container.hpp"
2222
"include/traccc/edm/details/host_container.hpp"
2323
"include/traccc/edm/measurement.hpp"
24+
"include/traccc/edm/measurement_helpers.hpp"
25+
"include/traccc/edm/impl/measurement_helpers.ipp"
2426
"include/traccc/edm/particle.hpp"
2527
"include/traccc/edm/track_parameters.hpp"
2628
"include/traccc/edm/container.hpp"
27-
"include/traccc/edm/track_state.hpp"
2829
"include/traccc/edm/silicon_cell_collection.hpp"
2930
"include/traccc/edm/impl/silicon_cell_collection.ipp"
3031
"include/traccc/edm/silicon_cluster_collection.hpp"
@@ -35,6 +36,12 @@ traccc_add_library( traccc_core core TYPE SHARED
3536
"include/traccc/edm/track_candidate_collection.hpp"
3637
"include/traccc/edm/impl/track_candidate_collection.ipp"
3738
"include/traccc/edm/track_candidate_container.hpp"
39+
"include/traccc/edm/track_state_collection.hpp"
40+
"include/traccc/edm/impl/track_state_collection.ipp"
41+
"include/traccc/edm/track_state_helpers.hpp"
42+
"include/traccc/edm/impl/track_state_helpers.ipp"
43+
"include/traccc/edm/track_fit_outcome.hpp"
44+
"include/traccc/edm/track_fit_collection.hpp"
3845
# Magnetic field description.
3946
"include/traccc/bfield/magnetic_field_types.hpp"
4047
"include/traccc/bfield/magnetic_field.hpp"
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
13+
TRACCC_HOST_DEVICE void get_measurement_local(
14+
const measurement& meas, detray::dmatrix<algebra_t, D, 1>& pos) {
15+
16+
static_assert(((D == 1u) || (D == 2u)),
17+
"The measurement dimension must be 1 or 2");
18+
19+
assert((meas.subs.get_indices()[0] == e_bound_loc0) ||
20+
(meas.subs.get_indices()[0] == e_bound_loc1));
21+
22+
const point2& local = meas.local;
23+
24+
switch (meas.subs.get_indices()[0]) {
25+
case e_bound_loc0:
26+
getter::element(pos, 0, 0) = local[0];
27+
if constexpr (D == 2u) {
28+
getter::element(pos, 1, 0) = local[1];
29+
}
30+
break;
31+
case e_bound_loc1:
32+
getter::element(pos, 0, 0) = local[1];
33+
if constexpr (D == 2u) {
34+
getter::element(pos, 1, 0) = local[0];
35+
}
36+
break;
37+
default:
38+
#if defined(__GNUC__)
39+
__builtin_unreachable();
40+
#endif
41+
}
42+
}
43+
44+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
45+
TRACCC_HOST_DEVICE void get_measurement_covariance(
46+
const measurement& meas, detray::dmatrix<algebra_t, D, D>& cov) {
47+
48+
static_assert(((D == 1u) || (D == 2u)),
49+
"The measurement dimension must be 1 or 2");
50+
51+
assert((meas.subs.get_indices()[0] == e_bound_loc0) ||
52+
(meas.subs.get_indices()[0] == e_bound_loc1));
53+
54+
const variance2& variance = meas.variance;
55+
56+
switch (meas.subs.get_indices()[0]) {
57+
case e_bound_loc0:
58+
getter::element(cov, 0, 0) = variance[0];
59+
if constexpr (D == 2u) {
60+
getter::element(cov, 0, 1) = 0.f;
61+
getter::element(cov, 1, 0) = 0.f;
62+
getter::element(cov, 1, 1) = variance[1];
63+
}
64+
break;
65+
case e_bound_loc1:
66+
getter::element(cov, 0, 0) = variance[1];
67+
if constexpr (D == 2u) {
68+
getter::element(cov, 0, 1) = 0.f;
69+
getter::element(cov, 1, 0) = 0.f;
70+
getter::element(cov, 1, 1) = variance[0];
71+
}
72+
break;
73+
default:
74+
#if defined(__GNUC__)
75+
__builtin_unreachable();
76+
#endif
77+
}
78+
}
79+
80+
} // namespace traccc::edm
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2022-2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <typename BASE>
13+
TRACCC_HOST_DEVICE void track_fit<BASE>::reset_quality() {
14+
15+
ndf() = {};
16+
chi2() = {};
17+
pval() = {};
18+
nholes() = {};
19+
}
20+
21+
} // namespace traccc::edm
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2022-2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <typename BASE>
13+
TRACCC_HOST_DEVICE bool track_state<BASE>::is_hole() const {
14+
15+
return (state() & IS_HOLE_MASK);
16+
}
17+
18+
template <typename BASE>
19+
TRACCC_HOST_DEVICE void track_state<BASE>::set_hole(bool value) {
20+
21+
if (value) {
22+
state() |= IS_HOLE_MASK;
23+
} else {
24+
state() &= ~IS_HOLE_MASK;
25+
}
26+
}
27+
28+
template <typename BASE>
29+
TRACCC_HOST_DEVICE bool track_state<BASE>::is_smoothed() const {
30+
31+
return (state() & IS_SMOOTHED_MASK);
32+
}
33+
34+
template <typename BASE>
35+
TRACCC_HOST_DEVICE void track_state<BASE>::set_smoothed(bool value) {
36+
37+
if (value) {
38+
state() |= IS_SMOOTHED_MASK;
39+
} else {
40+
state() &= ~IS_SMOOTHED_MASK;
41+
}
42+
}
43+
44+
} // namespace traccc::edm
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <typename algebra_t>
13+
TRACCC_HOST_DEVICE
14+
typename track_state_collection<algebra_t>::device::object_type
15+
make_track_state(
16+
const measurement_collection_types::const_device& measurements,
17+
unsigned int mindex) {
18+
19+
// Create the result object.
20+
typename track_state_collection<algebra_t>::device::object_type state{
21+
track_state_collection<algebra_t>::device::object_type::IS_HOLE_MASK,
22+
0.f,
23+
0.f,
24+
0.f,
25+
{},
26+
{},
27+
mindex};
28+
29+
// Set the correct surface link for the track parameters.
30+
state.filtered_params().set_surface_link(
31+
measurements.at(mindex).surface_link);
32+
state.smoothed_params().set_surface_link(
33+
measurements.at(mindex).surface_link);
34+
35+
// Return the initialized state.
36+
return state;
37+
}
38+
39+
} // namespace traccc::edm
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Local include(s).
11+
#include "traccc/definitions/primitives.hpp"
12+
#include "traccc/definitions/qualifiers.hpp"
13+
#include "traccc/edm/measurement.hpp"
14+
15+
namespace traccc::edm {
16+
17+
/// Get the local position of a measurement as a matrix
18+
///
19+
/// @tparam algebra_t The algebra type used to describe the tracks
20+
/// @tparam size_t The type of the matrix size variable
21+
/// @tparam D The dimension of the matrix
22+
///
23+
/// @param meas The measurement to extract the local position from
24+
/// @param pos The matrix to fill with the local position of the measurement
25+
///
26+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
27+
TRACCC_HOST_DEVICE void get_measurement_local(
28+
const measurement& meas, detray::dmatrix<algebra_t, D, 1>& pos);
29+
30+
/// Get the covariance of a measurement as a matrix
31+
///
32+
/// @tparam algebra_t The algebra type used to describe the tracks
33+
/// @tparam size_t The type of the matrix size variable
34+
/// @tparam D The dimension of the matrix
35+
///
36+
/// @param meas The measurement to extract the covariance from
37+
/// @param cov The matrix to fill with the covariance of the measurement
38+
///
39+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
40+
TRACCC_HOST_DEVICE void get_measurement_covariance(
41+
const measurement& meas, detray::dmatrix<algebra_t, D, D>& cov);
42+
43+
} // namespace traccc::edm
44+
45+
// Include the implementation.
46+
#include "traccc/edm/impl/measurement_helpers.ipp"

core/include/traccc/edm/track_candidate_container.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Local include(s).
11+
#include "traccc/definitions/qualifiers.hpp"
1112
#include "traccc/edm/measurement.hpp"
1213
#include "traccc/edm/track_candidate_collection.hpp"
1314

@@ -71,6 +72,28 @@ struct track_candidate_container {
7172
measurement_collection_types::const_view measurements;
7273
};
7374

75+
struct device {
76+
/// Constructor from a view
77+
TRACCC_HOST_DEVICE
78+
explicit device(const view& v)
79+
: tracks{v.tracks}, measurements{v.measurements} {}
80+
/// The track candidates
81+
track_candidate_collection<ALGEBRA>::device tracks;
82+
/// Measurements referenced by the tracks
83+
measurement_collection_types::const_device measurements;
84+
};
85+
86+
struct const_device {
87+
/// Constructor from a view
88+
TRACCC_HOST_DEVICE
89+
explicit const_device(const const_view& v)
90+
: tracks{v.tracks}, measurements{v.measurements} {}
91+
/// The track candidates
92+
track_candidate_collection<ALGEBRA>::const_device tracks;
93+
/// Measurements referenced by the tracks
94+
measurement_collection_types::const_device measurements;
95+
};
96+
7497
}; // struct track_candidate_container
7598

7699
} // namespace traccc::edm

0 commit comments

Comments
 (0)