Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions benchmarks/cpu/toy_detector_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
// VecMem copy object
vecmem::copy copy;

// Type declarations
using host_detector_type = traccc::default_detector::host;

// Read back detector file
host_detector_type det{host_mr};
traccc::host_detector det;
traccc::io::read_detector(
det, host_mr, sim_dir + "toy_detector_geometry.json",
sim_dir + "toy_detector_homogeneous_material.json",
Expand Down
11 changes: 5 additions & 6 deletions benchmarks/cuda/toy_detector_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
vecmem::cuda::async_copy async_copy{stream.cudaStream()};

// Read back detector file
traccc::default_detector::host det{cuda_host_mr};
traccc::host_detector det;
traccc::io::read_detector(
det, cuda_host_mr, sim_dir + "toy_detector_geometry.json",
sim_dir + "toy_detector_homogeneous_material.json",
Expand All @@ -68,9 +68,8 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
async_copy, stream);

// Copy detector to device
const auto det_buffer = detray::get_buffer(det, device_mr, copy);
// Detector view object
auto det_view = detray::get_data(det_buffer);
const traccc::detector_buffer det_buffer =
traccc::buffer_from_host_detector(det, device_mr, copy);

for (auto _ : state) {

Expand Down Expand Up @@ -120,13 +119,13 @@ BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
// Run CKF track finding
traccc::edm::track_candidate_collection<
traccc::default_algebra>::buffer track_candidates_cuda_buffer =
device_finding(det_view, field, measurements_cuda_buffer,
device_finding(det_buffer, field, measurements_cuda_buffer,
params_cuda_buffer);

// Run track fitting
traccc::edm::track_fit_container<traccc::default_algebra>::buffer
track_states_cuda_buffer = device_fitting(
det_view, field,
det_buffer, field,
{track_candidates_cuda_buffer, measurements_cuda_buffer});

// Create a temporary buffer that will receive the device memory.
Expand Down
6 changes: 0 additions & 6 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/finding/details/combinatorial_kalman_filter.hpp"
"include/traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
"src/finding/combinatorial_kalman_filter_algorithm.cpp"
"src/finding/combinatorial_kalman_filter_algorithm_default_detector.cpp"
"src/finding/combinatorial_kalman_filter_algorithm_telescope_detector.cpp"
# Fitting algorithmic code
"include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
"include/traccc/fitting/kalman_filter/kalman_actor.hpp"
Expand All @@ -99,8 +97,6 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/fitting/details/kalman_fitting.hpp"
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
"src/fitting/kalman_fitting_algorithm_default_detector.cpp"
"src/fitting/kalman_fitting_algorithm_telescope_detector.cpp"
# Seed finding algorithmic code.
"include/traccc/seeding/detail/lin_circle.hpp"
"include/traccc/seeding/detail/doublet.hpp"
Expand Down Expand Up @@ -130,8 +126,6 @@ traccc_add_library( traccc_core core TYPE SHARED
"src/seeding/silicon_pixel_spacepoint_formation.hpp"
"include/traccc/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
"src/seeding/silicon_pixel_spacepoint_formation_algorithm.cpp"
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_defdet.cpp"
"src/seeding/silicon_pixel_spacepoint_formation_algorithm_teldet.cpp"
# Ambiguity resolution
"include/traccc/ambiguity_resolution/ambiguity_resolution_config.hpp"
"include/traccc/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "traccc/edm/track_parameters.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/geometry/host_detector.hpp"
#include "traccc/utils/algorithm.hpp"
#include "traccc/utils/messaging.hpp"

Expand All @@ -32,11 +33,7 @@ namespace traccc::host {
///
class combinatorial_kalman_filter_algorithm
: public algorithm<edm::track_candidate_collection<default_algebra>::host(
const default_detector::host&, const magnetic_field&,
const measurement_collection_types::const_view&,
const bound_track_parameters_collection_types::const_view&)>,
public algorithm<edm::track_candidate_collection<default_algebra>::host(
const telescope_detector::host&, const magnetic_field&,
const host_detector&, const magnetic_field&,
const measurement_collection_types::const_view&,
const bound_track_parameters_collection_types::const_view&)>,
public messaging {
Expand All @@ -54,23 +51,7 @@ class combinatorial_kalman_filter_algorithm

/// Execute the algorithm
///
/// @param det The (default) detector object
/// @param bfield The magnetic field object
/// @param measurements All measurements in an event
/// @param seeds All seeds in an event to start the track finding
/// with
///
/// @return A container of the found track candidates
///
output_type operator()(
const default_detector::host& det, const magnetic_field& bfield,
const measurement_collection_types::const_view& measurements,
const bound_track_parameters_collection_types::const_view& seeds)
const override;

/// Execute the algorithm
///
/// @param det The (telescope) detector object
/// @param det The detector object
/// @param bfield The magnetic field object
/// @param measurements All measurements in an event
/// @param seeds All seeds in an event to start the track finding
Expand All @@ -79,7 +60,7 @@ class combinatorial_kalman_filter_algorithm
/// @return A container of the found track candidates
///
output_type operator()(
const telescope_detector::host& det, const magnetic_field& bfield,
const host_detector& det, const magnetic_field& bfield,
const measurement_collection_types::const_view& measurements,
const bound_track_parameters_collection_types::const_view& seeds)
const override;
Expand Down
23 changes: 4 additions & 19 deletions core/include/traccc/fitting/kalman_fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "traccc/edm/track_fit_container.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/geometry/host_detector.hpp"
#include "traccc/utils/algorithm.hpp"
#include "traccc/utils/messaging.hpp"

Expand All @@ -32,10 +33,7 @@ namespace traccc::host {
/// Kalman filter based track fitting algorithm
class kalman_fitting_algorithm
: public algorithm<edm::track_fit_container<default_algebra>::host(
const default_detector::host&, const magnetic_field&,
const edm::track_candidate_container<default_algebra>::const_view&)>,
public algorithm<edm::track_fit_container<default_algebra>::host(
const telescope_detector::host&, const magnetic_field&,
const host_detector&, const magnetic_field&,
const edm::track_candidate_container<default_algebra>::const_view&)>,
public messaging {

Expand All @@ -56,27 +54,14 @@ class kalman_fitting_algorithm

/// Execute the algorithm
///
/// @param det The (default) detector object
/// @param bfield The magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(
const default_detector::host& det, const magnetic_field& bfield,
const edm::track_candidate_container<default_algebra>::const_view&
track_candidates) const override;

/// Execute the algorithm
///
/// @param det The (telescope) detector object
/// @param det The detector object
/// @param bfield The magnetic field object
/// @param track_candidates All track candidates to fit
///
/// @return A container of the fitted track states
///
output_type operator()(
const telescope_detector::host& det, const magnetic_field& bfield,
const host_detector& det, const magnetic_field& bfield,
const edm::track_candidate_container<default_algebra>::const_view&
track_candidates) const override;

Expand Down
20 changes: 16 additions & 4 deletions core/include/traccc/geometry/detector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/definitions/primitives.hpp"

// Detray include(s).
#include <any>
#include <detray/core/detector.hpp>
#include <detray/detectors/default_metadata.hpp>
#include <detray/detectors/telescope_metadata.hpp>
Expand Down Expand Up @@ -47,7 +48,7 @@ struct device_detector_container_types {

/// Base struct for the different detector types supported by the project.
template <typename metadata_t>
struct detector {
struct detector_traits {

/// Metadata type of the detector.
using metadata_type = metadata_t;
Expand All @@ -66,15 +67,26 @@ struct detector {

}; // struct default_detector

template <typename T>
concept is_detector_traits = requires {
typename T::metadata_type;
typename T::host;
typename T::device;
typename T::view;
typename T::buffer;
};

/// Default detector (also used for ODD)
using default_detector =
detector<detray::default_metadata<traccc::default_algebra>>;
detector_traits<detray::default_metadata<traccc::default_algebra>>;

/// Telescope detector
using telescope_detector = detector<
using telescope_detector = detector_traits<
detray::telescope_metadata<traccc::default_algebra, detray::rectangle2D>>;

/// Toy detector
using toy_detector = detector<detray::toy_metadata<traccc::default_algebra>>;
using toy_detector =
detector_traits<detray::toy_metadata<traccc::default_algebra>>;

using detector_type_list = std::tuple<default_detector, telescope_detector>;
} // namespace traccc
117 changes: 117 additions & 0 deletions core/include/traccc/geometry/detector_buffer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s).
#include "traccc/geometry/detector.hpp"
#include "traccc/geometry/host_detector.hpp"
#include "traccc/geometry/move_only_any.hpp"

// Detray include(s).
#include <any>
#include <detray/core/detector.hpp>
#include <detray/detectors/default_metadata.hpp>
#include <detray/detectors/telescope_metadata.hpp>
#include <detray/detectors/toy_metadata.hpp>

namespace traccc {

class detector_buffer {
public:
detector_buffer() = default;
detector_buffer(const detector_buffer&) = delete;
detector_buffer(detector_buffer&&) = default;
detector_buffer& operator=(const detector_buffer&) = delete;
detector_buffer& operator=(detector_buffer&&) = default;

template <typename detector_traits_t>
void set(typename detector_traits_t::buffer&& obj)
requires(is_detector_traits<detector_traits_t>)
{
m_obj.set<typename detector_traits_t::buffer>(std::move(obj));
}

template <typename detector_traits_t>
bool is() const
requires(is_detector_traits<detector_traits_t>)
{
return (type() == typeid(typename detector_traits_t::buffer));
}

const std::type_info& type() const { return m_obj.type(); }

template <typename detector_traits_t>
const typename detector_traits_t::buffer& as() const
requires(is_detector_traits<detector_traits_t>)
{
return m_obj.as<typename detector_traits_t::buffer>();
}

template <typename detector_traits_t>
typename detector_traits_t::view as_view() const
requires(is_detector_traits<detector_traits_t>)
{
return detray::get_data(as<detector_traits_t>());
}

private:
move_only_any m_obj;
}; // class bfield

/// @brief Helper function for `detector_buffer_visitor`
template <typename callable_t, typename detector_t, typename... detector_ts>
auto detector_buffer_visitor_helper(const detector_buffer& detector_buffer,
callable_t&& callable,
std::tuple<detector_t, detector_ts...>*) {
if (detector_buffer.is<detector_t>()) {
return callable.template operator()<detector_t>(
detector_buffer.as_view<detector_t>());
} else {
if constexpr (sizeof...(detector_ts) > 0) {
return detector_buffer_visitor_helper(
detector_buffer, std::forward<callable_t>(callable),
static_cast<std::tuple<detector_ts...>*>(nullptr));
} else {
std::stringstream exception_message;

exception_message
<< "Invalid detector type (" << detector_buffer.type().name()
<< ") received, but this type is not supported" << std::endl;

throw std::invalid_argument(exception_message.str());
}
}
}

/// @brief Visitor for polymorphic detector buffer types
///
/// This function takes a list of supported detector trait types and checks
/// if the provided field is one of them. If it is, it will call the provided
/// callable on a view of it and otherwise it will throw an exception.
template <typename detector_buffer_list_t, typename callable_t>
auto detector_buffer_visitor(const detector_buffer& detector_buffer,
callable_t&& callable) {
return detector_buffer_visitor_helper(
detector_buffer, std::forward<callable_t>(callable),
static_cast<detector_buffer_list_t*>(nullptr));
}

// TODO: Docs
inline detector_buffer buffer_from_host_detector(const host_detector& det,
vecmem::memory_resource& mr,
vecmem::copy& copy) {
return host_detector_visitor<detector_type_list>(
det, [&mr, &copy]<typename detector_traits_t>(
const typename detector_traits_t::host& detector) {
traccc::detector_buffer rv;
rv.set<detector_traits_t>(detray::get_buffer(detector, mr, copy));
return rv;
});
}

} // namespace traccc
Loading
Loading