Skip to content

Commit 9f7f77f

Browse files
committed
Adapted the tests and benchmarks to the SoA track fit EDM.
1 parent b81e99b commit 9f7f77f

18 files changed

+150
-130
lines changed

benchmarks/cuda/toy_detector_cuda.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
7373
// Detector view object
7474
auto det_view = detray::get_data(det_buffer);
7575

76-
// D2H copy object
77-
traccc::device::container_d2h_copy_alg<traccc::track_state_container_types>
78-
track_state_d2h{{device_mr, &host_mr}, copy};
79-
8076
for (auto _ : state) {
8177

8278
state.PauseTiming();
@@ -129,7 +125,7 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
129125
params_cuda_buffer);
130126

131127
// Run track fitting
132-
traccc::track_state_container_types::buffer
128+
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
133129
track_states_cuda_buffer = device_fitting(
134130
det_view, field,
135131
{track_candidates_cuda_buffer, measurements_cuda_buffer});

tests/common/tests/kalman_fitting_momentum_resolution_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
namespace traccc {
2828

2929
void KalmanFittingMomentumResolutionTests::consistency_tests(
30-
const track_state_collection_types::host& track_states_per_track) const {
30+
const edm::track_fit_collection<default_algebra>::host::const_proxy_type&
31+
track,
32+
const edm::track_state_collection<default_algebra>::host&) const {
3133

3234
// The nubmer of track states is supposed be equal to the number
3335
// of planes
34-
ASSERT_EQ(track_states_per_track.size(), std::get<11>(GetParam()));
36+
ASSERT_EQ(track.state_indices().size(), std::get<11>(GetParam()));
3537
}
3638

3739
void KalmanFittingMomentumResolutionTests::momentum_resolution_tests(

tests/common/tests/kalman_fitting_momentum_resolution_test.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ class KalmanFittingMomentumResolutionTests
5757
1.f * traccc::unit<scalar>::ns};
5858

5959
void consistency_tests(
60-
const track_state_collection_types::host& track_states_per_track) const;
60+
const edm::track_fit_collection<
61+
default_algebra>::host::const_proxy_type& track,
62+
const edm::track_state_collection<default_algebra>::host& track_states)
63+
const;
6164

6265
void momentum_resolution_tests(std::string_view file_name) const;
6366

tests/common/tests/kalman_fitting_telescope_test.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,14 @@ class KalmanFittingTelescopeTests
7777
std::get<11>(GetParam()));
7878
}
7979

80-
void consistency_tests(const track_state_collection_types::host&
81-
track_states_per_track) const {
80+
void consistency_tests(
81+
const edm::track_fit_collection<
82+
default_algebra>::host::const_proxy_type& track,
83+
const edm::track_state_collection<default_algebra>::host&) const {
8284

8385
// The nubmer of track states is supposed be equal to the number
8486
// of planes
85-
ASSERT_EQ(track_states_per_track.size(), std::get<11>(GetParam()));
87+
ASSERT_EQ(track.state_indices().size(), std::get<11>(GetParam()));
8688
}
8789

8890
protected:

tests/common/tests/kalman_fitting_test.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,31 +151,51 @@ void KalmanFittingTests::ndf_tests(
151151
}
152152

153153
void KalmanFittingTests::ndf_tests(
154-
const fitting_result<traccc::default_algebra>& fit_res,
155-
const track_state_collection_types::host& track_states_per_track) {
154+
const edm::track_fit_collection<default_algebra>::host::const_proxy_type&
155+
track,
156+
const edm::track_state_collection<default_algebra>::host& track_states,
157+
const measurement_collection_types::host& measurements) {
156158

157159
scalar dim_sum = 0;
158160
std::size_t n_effective_states = 0;
159161

160-
for (const auto& state : track_states_per_track) {
162+
for (unsigned int state_idx : track.state_indices()) {
163+
164+
auto state = track_states.at(state_idx);
161165

162-
if (!state.is_hole && state.is_smoothed) {
166+
if (!state.is_hole() && state.is_smoothed()) {
163167

164-
dim_sum += static_cast<scalar>(state.get_measurement().meas_dim);
168+
dim_sum += static_cast<scalar>(
169+
measurements.at(state.measurement_index()).meas_dim);
165170
n_effective_states++;
166171
}
167172
}
168173

169174
// Check if the number of degree of freedoms is equal to (the sum of
170175
// measurement dimensions - 5)
171-
ASSERT_FLOAT_EQ(static_cast<float>(fit_res.trk_quality.ndf),
176+
ASSERT_FLOAT_EQ(static_cast<float>(track.ndf()),
172177
static_cast<float>(dim_sum) - 5.f);
173178

174179
// The number of track states is supposed to be eqaul to the number
175180
// of measurements unless KF failes in the middle of propagation
176-
if (n_effective_states == track_states_per_track.size()) {
181+
if (n_effective_states == track.state_indices().size()) {
177182
n_success++;
178183
}
179184
}
180185

186+
std::size_t KalmanFittingTests::count_successfully_fitted_tracks(
187+
const edm::track_fit_collection<default_algebra>::host& tracks) const {
188+
189+
const std::size_t n_tracks = tracks.size();
190+
std::size_t n_fitted_tracks = 0u;
191+
192+
for (std::size_t i = 0; i < n_tracks; ++i) {
193+
if (tracks.at(i).fit_outcome() == track_fit_outcome::SUCCESS) {
194+
n_fitted_tracks++;
195+
}
196+
}
197+
198+
return n_fitted_tracks;
199+
}
200+
181201
} // namespace traccc

tests/common/tests/kalman_fitting_test.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "traccc/definitions/common.hpp"
1313
#include "traccc/edm/measurement.hpp"
1414
#include "traccc/edm/track_candidate_collection.hpp"
15+
#include "traccc/edm/track_fit_collection.hpp"
16+
#include "traccc/edm/track_state_collection.hpp"
1517
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
1618
#include "traccc/geometry/detector.hpp"
1719
#include "traccc/simulation/event_generators.hpp"
@@ -79,12 +81,23 @@ class KalmanFittingTests : public testing::Test {
7981

8082
/// Validadte the NDF for track fitting output
8183
///
82-
/// @param fit_res Fitting statistics result of a track
83-
/// @param track_states_per_track Track states of a track
84+
/// @param track Fitting statistics result of a track
85+
/// @param track_states All track states in the event
86+
/// @param measurements All measurements in the event
8487
///
8588
void ndf_tests(
86-
const fitting_result<traccc::default_algebra>& fit_res,
87-
const track_state_collection_types::host& track_states_per_track);
89+
const edm::track_fit_collection<
90+
default_algebra>::host::const_proxy_type& track,
91+
const edm::track_state_collection<default_algebra>::host& track_states,
92+
const measurement_collection_types::host& measurements);
93+
94+
/// Count the number of tracks that were successfully fitted
95+
///
96+
/// @param tracks The track fit collection to count on
97+
/// @return The number of tracks that were successfully fitted
98+
///
99+
std::size_t count_successfully_fitted_tracks(
100+
const edm::track_fit_collection<default_algebra>::host& tracks) const;
88101

89102
// The number of tracks successful with KF
90103
std::size_t n_success{0u};

tests/common/tests/kalman_fitting_wire_chamber_test.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,14 @@ class KalmanFittingWireChamberTests
7070
0.05f / traccc::unit<scalar>::GeV,
7171
1.f * traccc::unit<scalar>::ns};
7272

73-
void consistency_tests(const track_state_collection_types::host&
74-
track_states_per_track) const {
73+
void consistency_tests(
74+
const edm::track_fit_collection<
75+
default_algebra>::host::const_proxy_type& track,
76+
const edm::track_state_collection<default_algebra>::host&) const {
7577

7678
// The nubmer of track states is supposed be greater than or
7779
// equal to the number of layers
78-
ASSERT_GE(track_states_per_track.size(), n_wire_layers);
80+
ASSERT_GE(track.state_indices().size(), n_wire_layers);
7981
}
8082

8183
protected:

tests/cpu/test_ckf_sparse_tracks_telescope.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,22 @@ TEST_P(CkfSparseTrackTelescopeTests, Run) {
181181
{vecmem::get_data(track_candidates),
182182
vecmem::get_data(measurements_per_event)});
183183
const std::size_t n_fitted_tracks =
184-
count_successfully_fitted_tracks(track_states);
184+
count_successfully_fitted_tracks(track_states.tracks);
185185

186-
ASSERT_EQ(track_states.size(), n_truth_tracks);
187-
ASSERT_EQ(track_states.size(), n_fitted_tracks);
186+
ASSERT_EQ(track_states.tracks.size(), n_truth_tracks);
187+
ASSERT_EQ(track_states.tracks.size(), n_fitted_tracks);
188188

189189
for (unsigned int i_trk = 0; i_trk < n_truth_tracks; i_trk++) {
190190

191-
const auto& track_states_per_track = track_states[i_trk].items;
192-
const auto& fit_res = track_states[i_trk].header;
191+
consistency_tests(track_states.tracks.at(i_trk),
192+
track_states.states);
193193

194-
consistency_tests(track_states_per_track);
194+
ndf_tests(track_states.tracks.at(i_trk), track_states.states,
195+
measurements_per_event);
195196

196-
ndf_tests(fit_res, track_states_per_track);
197-
198-
fit_performance_writer.write(track_states_per_track, fit_res,
199-
host_det, evt_data);
197+
fit_performance_writer.write(
198+
track_states.tracks.at(i_trk), track_states.states,
199+
measurements_per_event, host_det, evt_data);
200200
}
201201
}
202202

tests/cpu/test_kalman_fitter_hole_count.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
// Project include(s).
99
#include "traccc/bfield/construct_const_bfield.hpp"
10-
#include "traccc/edm/track_state.hpp"
1110
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
1211
#include "traccc/io/utils.hpp"
1312
#include "traccc/resolution/fitting_performance_writer.hpp"
@@ -161,19 +160,19 @@ TEST_P(KalmanFittingHoleCountTests, Run) {
161160
vecmem::get_data(track_candidates.measurements)});
162161

163162
// A sanity check
164-
const std::size_t n_tracks = track_states.size();
163+
const std::size_t n_tracks = track_states.tracks.size();
165164
ASSERT_EQ(n_tracks, n_truth_tracks);
166165

167166
// Check the number of holes
168167
// The three holes at the end are not counted as KF aborts once it goes
169168
// through all track candidates
170-
const auto& fit_res = track_states.at(0u).header;
171-
ASSERT_EQ(fit_res.trk_quality.n_holes, 5u);
169+
const auto track = track_states.tracks.at(0u);
170+
ASSERT_EQ(track.nholes(), 5u);
172171

173172
// Some sanity checks
174173
ASSERT_FLOAT_EQ(
175-
static_cast<float>(fit_res.trk_quality.ndf),
176-
static_cast<float>(track_states.at(0u).items.size()) * 2.f - 5.f);
174+
static_cast<float>(track.ndf()),
175+
static_cast<float>(track.state_indices().size()) * 2.f - 5.f);
177176
}
178177

179178
INSTANTIATE_TEST_SUITE_P(

tests/cpu/test_kalman_fitter_momentum_resolution.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
// Project include(s).
99
#include "traccc/bfield/construct_const_bfield.hpp"
10-
#include "traccc/edm/track_state.hpp"
1110
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
1211
#include "traccc/io/utils.hpp"
1312
#include "traccc/resolution/fitting_performance_writer.hpp"
@@ -173,9 +172,9 @@ TEST_P(KalmanFittingMomentumResolutionTests, Run) {
173172
vecmem::get_data(track_candidates.measurements)});
174173

175174
// Iterator over tracks
176-
const std::size_t n_tracks = track_states.size();
175+
const std::size_t n_tracks = track_states.tracks.size();
177176
const std::size_t n_fitted_tracks =
178-
count_successfully_fitted_tracks(track_states);
177+
count_successfully_fitted_tracks(track_states.tracks);
179178

180179
// n_trakcs = 100
181180
ASSERT_GE(static_cast<float>(n_tracks),
@@ -185,17 +184,17 @@ TEST_P(KalmanFittingMomentumResolutionTests, Run) {
185184

186185
for (std::size_t i_trk = 0; i_trk < n_tracks; i_trk++) {
187186

188-
const auto& track_states_per_track = track_states[i_trk].items;
189-
const auto& fit_res = track_states[i_trk].header;
187+
consistency_tests(track_states.tracks.at(i_trk),
188+
track_states.states);
190189

191-
consistency_tests(track_states_per_track);
190+
ndf_tests(track_states.tracks.at(i_trk), track_states.states,
191+
track_candidates.measurements);
192192

193-
ndf_tests(fit_res, track_states_per_track);
193+
ASSERT_EQ(track_states.tracks.at(i_trk).nholes(), 0u);
194194

195-
ASSERT_EQ(fit_res.trk_quality.n_holes, 0u);
196-
197-
fit_performance_writer.write(track_states_per_track, fit_res,
198-
host_det, evt_data);
195+
fit_performance_writer.write(
196+
track_states.tracks.at(i_trk), track_states.states,
197+
track_candidates.measurements, host_det, evt_data);
199198
}
200199
}
201200

0 commit comments

Comments
 (0)