Skip to content

Commit e204f3d

Browse files
committed
Moved "complex EDM operations" to helper functions.
1 parent a4a6900 commit e204f3d

File tree

13 files changed

+225
-152
lines changed

13 files changed

+225
-152
lines changed

core/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ traccc_add_library( traccc_core core TYPE SHARED
2121
"include/traccc/edm/details/device_container.hpp"
2222
"include/traccc/edm/details/host_container.hpp"
2323
"include/traccc/edm/measurement.hpp"
24+
"include/traccc/edm/measurement_helpers.hpp"
25+
"include/traccc/edm/impl/measurement_helpers.ipp"
2426
"include/traccc/edm/particle.hpp"
2527
"include/traccc/edm/track_parameters.hpp"
2628
"include/traccc/edm/container.hpp"
@@ -36,6 +38,8 @@ traccc_add_library( traccc_core core TYPE SHARED
3638
"include/traccc/edm/track_candidate_container.hpp"
3739
"include/traccc/edm/track_state_collection.hpp"
3840
"include/traccc/edm/impl/track_state_collection.ipp"
41+
"include/traccc/edm/track_state_helpers.hpp"
42+
"include/traccc/edm/impl/track_state_helpers.ipp"
3943
"include/traccc/edm/track_fit_outcome.hpp"
4044
"include/traccc/edm/track_fit_collection.hpp"
4145
# Magnetic field description.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
13+
TRACCC_HOST_DEVICE void get_measurement_local(
14+
const measurement& meas, detray::dmatrix<algebra_t, D, 1>& pos) {
15+
16+
static_assert(((D == 1u) || (D == 2u)),
17+
"The measurement dimension must be 1 or 2");
18+
19+
assert((meas.subs.get_indices()[0] == e_bound_loc0) ||
20+
(meas.subs.get_indices()[0] == e_bound_loc1));
21+
22+
const point2& local = meas.local;
23+
24+
switch (meas.subs.get_indices()[0]) {
25+
case e_bound_loc0:
26+
getter::element(pos, 0, 0) = local[0];
27+
if constexpr (D == 2u) {
28+
getter::element(pos, 1, 0) = local[1];
29+
}
30+
break;
31+
case e_bound_loc1:
32+
getter::element(pos, 0, 0) = local[1];
33+
if constexpr (D == 2u) {
34+
getter::element(pos, 1, 0) = local[0];
35+
}
36+
break;
37+
default:
38+
#if defined(__GNUC__)
39+
__builtin_unreachable();
40+
#endif
41+
}
42+
}
43+
44+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
45+
TRACCC_HOST_DEVICE void get_measurement_covariance(
46+
const measurement& meas, detray::dmatrix<algebra_t, D, D>& cov) {
47+
48+
static_assert(((D == 1u) || (D == 2u)),
49+
"The measurement dimension must be 1 or 2");
50+
51+
assert((meas.subs.get_indices()[0] == e_bound_loc0) ||
52+
(meas.subs.get_indices()[0] == e_bound_loc1));
53+
54+
const variance2& variance = meas.variance;
55+
56+
switch (meas.subs.get_indices()[0]) {
57+
case e_bound_loc0:
58+
getter::element(cov, 0, 0) = variance[0];
59+
if constexpr (D == 2u) {
60+
getter::element(cov, 0, 1) = 0.f;
61+
getter::element(cov, 1, 0) = 0.f;
62+
getter::element(cov, 1, 1) = variance[1];
63+
}
64+
break;
65+
case e_bound_loc1:
66+
getter::element(cov, 0, 0) = variance[1];
67+
if constexpr (D == 2u) {
68+
getter::element(cov, 0, 1) = 0.f;
69+
getter::element(cov, 1, 0) = 0.f;
70+
getter::element(cov, 1, 1) = variance[0];
71+
}
72+
break;
73+
default:
74+
#if defined(__GNUC__)
75+
__builtin_unreachable();
76+
#endif
77+
}
78+
}
79+
80+
} // namespace traccc::edm

core/include/traccc/edm/impl/track_state_collection.ipp

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -41,82 +41,4 @@ TRACCC_HOST_DEVICE void track_state<BASE>::set_smoothed(bool value) {
4141
}
4242
}
4343

44-
template <typename BASE>
45-
template <detray::concepts::algebra ALGEBRA, std::integral size_type,
46-
size_type D>
47-
TRACCC_HOST_DEVICE void track_state<BASE>::get_measurement_local(
48-
const measurement_collection_types::const_device& measurements,
49-
detray::dmatrix<ALGEBRA, D, 1>& pos) const {
50-
51-
static_assert(((D == 1u) || (D == 2u)),
52-
"The measurement dimension must be 1 or 2");
53-
54-
assert((measurements.at(measurement_index()).subs.get_indices()[0] ==
55-
e_bound_loc0) ||
56-
(measurements.at(measurement_index()).subs.get_indices()[0] ==
57-
e_bound_loc1));
58-
59-
const point2& local = measurements.at(measurement_index()).local;
60-
61-
switch (measurements.at(measurement_index()).subs.get_indices()[0]) {
62-
case e_bound_loc0:
63-
getter::element(pos, 0, 0) = local[0];
64-
if constexpr (D == 2u) {
65-
getter::element(pos, 1, 0) = local[1];
66-
}
67-
break;
68-
case e_bound_loc1:
69-
getter::element(pos, 0, 0) = local[1];
70-
if constexpr (D == 2u) {
71-
getter::element(pos, 1, 0) = local[0];
72-
}
73-
break;
74-
default:
75-
#if defined(__GNUC__)
76-
__builtin_unreachable();
77-
#endif
78-
}
79-
}
80-
81-
template <typename BASE>
82-
template <detray::concepts::algebra ALGEBRA, std::integral size_type,
83-
size_type D>
84-
TRACCC_HOST_DEVICE void track_state<BASE>::get_measurement_covariance(
85-
const measurement_collection_types::const_device& measurements,
86-
detray::dmatrix<ALGEBRA, D, D>& cov) const {
87-
88-
static_assert(((D == 1u) || (D == 2u)),
89-
"The measurement dimension must be 1 or 2");
90-
91-
assert((measurements.at(measurement_index()).subs.get_indices()[0] ==
92-
e_bound_loc0) ||
93-
(measurements.at(measurement_index()).subs.get_indices()[0] ==
94-
e_bound_loc1));
95-
96-
const variance2& variance = measurements.at(measurement_index()).variance;
97-
98-
switch (measurements.at(measurement_index()).subs.get_indices()[0]) {
99-
case e_bound_loc0:
100-
getter::element(cov, 0, 0) = variance[0];
101-
if constexpr (D == 2u) {
102-
getter::element(cov, 0, 1) = 0.f;
103-
getter::element(cov, 1, 0) = 0.f;
104-
getter::element(cov, 1, 1) = variance[1];
105-
}
106-
break;
107-
case e_bound_loc1:
108-
getter::element(cov, 0, 0) = variance[1];
109-
if constexpr (D == 2u) {
110-
getter::element(cov, 0, 1) = 0.f;
111-
getter::element(cov, 1, 0) = 0.f;
112-
getter::element(cov, 1, 1) = variance[0];
113-
}
114-
break;
115-
default:
116-
#if defined(__GNUC__)
117-
__builtin_unreachable();
118-
#endif
119-
}
120-
}
121-
12244
} // namespace traccc::edm
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::edm {
11+
12+
template <typename algebra_t>
13+
TRACCC_HOST_DEVICE
14+
typename track_state_collection<algebra_t>::device::object_type
15+
make_track_state(
16+
const measurement_collection_types::const_device& measurements,
17+
unsigned int mindex) {
18+
19+
// Create the result object.
20+
typename track_state_collection<algebra_t>::device::object_type state{
21+
track_state_collection<algebra_t>::device::object_type::IS_HOLE_MASK,
22+
0.f,
23+
0.f,
24+
0.f,
25+
{},
26+
{},
27+
mindex};
28+
29+
// Set the correct surface link for the track parameters.
30+
state.filtered_params().set_surface_link(
31+
measurements.at(mindex).surface_link);
32+
state.smoothed_params().set_surface_link(
33+
measurements.at(mindex).surface_link);
34+
35+
// Return the initialized state.
36+
return state;
37+
}
38+
39+
} // namespace traccc::edm
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Local include(s).
11+
#include "traccc/definitions/primitives.hpp"
12+
#include "traccc/definitions/qualifiers.hpp"
13+
#include "traccc/edm/measurement.hpp"
14+
15+
namespace traccc::edm {
16+
17+
/// Get the local position of a measurement as a matrix
18+
///
19+
/// @tparam algebra_t The algebra type used to describe the tracks
20+
/// @tparam size_t The type of the matrix size variable
21+
/// @tparam D The dimension of the matrix
22+
///
23+
/// @param meas The measurement to extract the local position from
24+
/// @param pos The matrix to fill with the local position of the measurement
25+
///
26+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
27+
TRACCC_HOST_DEVICE void get_measurement_local(
28+
const measurement& meas, detray::dmatrix<algebra_t, D, 1>& pos);
29+
30+
/// Get the covariance of a measurement as a matrix
31+
///
32+
/// @tparam algebra_t The algebra type used to describe the tracks
33+
/// @tparam size_t The type of the matrix size variable
34+
/// @tparam D The dimension of the matrix
35+
///
36+
/// @param meas The measurement to extract the covariance from
37+
/// @param cov The matrix to fill with the covariance of the measurement
38+
///
39+
template <detray::concepts::algebra algebra_t, std::integral size_t, size_t D>
40+
TRACCC_HOST_DEVICE void get_measurement_covariance(
41+
const measurement& meas, detray::dmatrix<algebra_t, D, D>& cov);
42+
43+
} // namespace traccc::edm
44+
45+
// Include the implementation.
46+
#include "traccc/edm/impl/measurement_helpers.ipp"

core/include/traccc/edm/track_state_collection.hpp

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -191,42 +191,6 @@ class track_state : public BASE {
191191
TRACCC_HOST_DEVICE
192192
void set_smoothed(bool value = true);
193193

194-
/// Get the local position of the measurement in a matrix
195-
///
196-
/// @note This function must only be used on proxy objects, not on
197-
/// containers!
198-
///
199-
/// @tparam size_type The type of the matrix size variable
200-
/// @tparam ALGEBRA The algebra type used to describe the tracks
201-
/// @tparam D The dimension of the matrix
202-
///
203-
/// @param measurements All measurements in the event
204-
/// @param pos The matrix to fill with the local position of the measurement
205-
///
206-
template <detray::concepts::algebra ALGEBRA, std::integral size_type,
207-
size_type D>
208-
TRACCC_HOST_DEVICE void get_measurement_local(
209-
const measurement_collection_types::const_device& measurements,
210-
detray::dmatrix<ALGEBRA, D, 1>& pos) const;
211-
212-
/// Get the covariance of the measurement in a matrix
213-
///
214-
/// @note This function must only be used on proxy objects, not on
215-
/// containers!
216-
///
217-
/// @tparam size_type The type of the matrix size variable
218-
/// @tparam ALGEBRA The algebra type used to describe the tracks
219-
/// @tparam D The dimension of the matrix
220-
///
221-
/// @param measurements All measurements in the event
222-
/// @param cov The matrix to fill with the covariance of the measurement
223-
///
224-
template <detray::concepts::algebra ALGEBRA, std::integral size_type,
225-
size_type D>
226-
TRACCC_HOST_DEVICE void get_measurement_covariance(
227-
const measurement_collection_types::const_device& measurements,
228-
detray::dmatrix<ALGEBRA, D, D>& cov) const;
229-
230194
/// @}
231195

232196
}; // class track_state
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Local include(s).
11+
#include "traccc/edm/track_state_collection.hpp"
12+
13+
namespace traccc::edm {
14+
15+
/// Create a track state with default values.
16+
///
17+
/// @param measurements The collection of measurements to use for initialization
18+
/// @param mindex The index of the measurement to associate with the state
19+
///
20+
/// @return A track state object initialized with default values
21+
///
22+
template <typename algebra_t>
23+
TRACCC_HOST_DEVICE
24+
typename track_state_collection<algebra_t>::device::object_type
25+
make_track_state(
26+
const measurement_collection_types::const_device& measurements,
27+
unsigned int mindex);
28+
29+
} // namespace traccc::edm
30+
31+
// Include the implementation.
32+
#include "traccc/edm/impl/track_state_helpers.ipp"

core/include/traccc/finding/details/combinatorial_kalman_filter.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// Project include(s).
1010
#include "traccc/edm/measurement.hpp"
1111
#include "traccc/edm/track_candidate_collection.hpp"
12+
#include "traccc/edm/track_state_helpers.hpp"
1213
#include "traccc/finding/actors/ckf_aborter.hpp"
1314
#include "traccc/finding/actors/interaction_register.hpp"
1415
#include "traccc/finding/candidate_link.hpp"
@@ -264,13 +265,8 @@ combinatorial_kalman_filter(
264265
const measurement& meas = measurements.at(item_id);
265266

266267
// Create a standalone track state object.
267-
typename edm::track_state_collection<
268-
algebra_type>::host::object_type trk_state(0u, 0.f, 0.f,
269-
0.f, {}, {},
270-
item_id);
271-
trk_state.set_hole(true);
272-
trk_state.set_smoothed(false);
273-
trk_state.filtered_params().set_surface_link(meas.surface_link);
268+
auto trk_state =
269+
edm::make_track_state<algebra_type>(measurements, item_id);
274270

275271
const bool is_line = sf.template visit_mask<is_line_visitor>();
276272

core/include/traccc/fitting/details/kalman_fitting.hpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "traccc/edm/track_fit_collection.hpp"
1414
#include "traccc/edm/track_fit_container.hpp"
1515
#include "traccc/edm/track_state_collection.hpp"
16+
#include "traccc/edm/track_state_helpers.hpp"
1617
#include "traccc/fitting/status_codes.hpp"
1718

1819
// VecMem include(s).
@@ -69,15 +70,8 @@ typename edm::track_fit_container<algebra_t>::host kalman_fitting(
6970
track_candidates.measurement_indices().at(i)) {
7071
fitted_track.state_indices().push_back(
7172
static_cast<unsigned int>(result.states.size()));
72-
result.states.push_back(
73-
{0u, 0.f, 0.f, 0.f, {}, {}, measurement_index});
74-
auto state = result.states.at(result.states.size() - 1);
75-
state.set_hole(true);
76-
state.set_smoothed(false);
77-
state.filtered_params().set_surface_link(
78-
measurements.at(measurement_index).surface_link);
79-
state.smoothed_params().set_surface_link(
80-
measurements.at(measurement_index).surface_link);
73+
result.states.push_back(edm::make_track_state<algebra_t>(
74+
measurements, measurement_index));
8175
}
8276

8377
vecmem::data::vector_buffer<detray::geometry::barcode> seqs_buffer{

0 commit comments

Comments
 (0)