Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions core/include/traccc/seeding/detail/seeding_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ struct seedfilter_config {
float deltaRMin = 5.f * unit<float>::mm;
// how often do you want to increase the weight of a seed for finding a
// compatible seed?
size_t compatSeedLimit = 2;
unsigned int compatSeedLimit = 2;

// seed weight increase
float good_spB_min_radius = 150.f * unit<float>::mm;
Expand All @@ -211,7 +211,6 @@ struct seedfilter_config {
float good_spB_min_weight = 380.f;

// seed cut
float seed_min_weight = 200.f;
float spB_min_radius = 43.f * unit<float>::mm;
};

Expand Down
17 changes: 0 additions & 17 deletions core/include/traccc/seeding/seed_selecting_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,6 @@ struct seed_selecting_helper {
return !(spB.radius() > filter_config.good_spB_min_radius &&
triplet_weight < filter_config.good_spB_min_weight);
}

/// Cut triplets with criteria
///
/// @param filter_config seed filtering configuration parameters
/// @param spacepoints spacepoint collection
/// @param sp_grid Spacepoint grid
/// @param seed current seed to possibly cut
///
/// @return boolean value
template <typename spacepoint_type>
static TRACCC_HOST_DEVICE bool cut_per_middle_sp(
const seedfilter_config& filter_config, const spacepoint_type& spB,
const scalar weight) {

return (weight > filter_config.seed_min_weight ||
spB.radius() > filter_config.spB_min_radius);
}
};

} // namespace traccc
14 changes: 9 additions & 5 deletions core/src/seeding/seed_filtering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,15 @@ void seed_filtering::operator()(
const traccc::details::spacepoint_grid_types::const_device
sp_grid_accessor(sp_grid_data);
const auto& this_seed = triplets_passing_single_seed_cuts[i].get();
if (seed_selecting_helper::cut_per_middle_sp(
m_filter_config,
spacepoints.at(sp_grid_accessor.bin(
this_seed.sp1.bin_idx)[this_seed.sp1.sp_idx]),
this_seed.weight)) {

const scalar spB_radius =
spacepoints
.at(sp_grid_accessor.bin(
this_seed.sp1.bin_idx)[this_seed.sp1.sp_idx])
.radius();

if (this_seed.weight > 200.f ||
spB_radius > m_filter_config.spB_min_radius) {
triplets_passing_final_cuts.push_back(
triplets_passing_single_seed_cuts[i]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ namespace traccc::device {
/// @param[inout] triplet_view Collection of triplets
///
TRACCC_HOST_DEVICE
inline void update_triplet_weights(
inline void find_triplet_confirmations(
global_index_t globalIndex, const seedfilter_config& filter_config,
const edm::spacepoint_collection::const_view& spacepoints,
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
device_triplet_collection_types::view triplet_view);
const device_triplet_collection_types::const_view triplet_view,
vecmem::data::vector_view<unsigned int> num_confirmations_view);

} // namespace traccc::device

// Include the implementation.
#include "traccc/seeding/device/impl/update_triplet_weights.ipp"
#include "traccc/seeding/device/impl/find_triplet_confirmations.ipp"
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@
namespace traccc::device {

TRACCC_HOST_DEVICE
inline void update_triplet_weights(
inline void find_triplet_confirmations(
const global_index_t globalIndex, const seedfilter_config& filter_config,
const edm::spacepoint_collection::const_view& spacepoints_view,
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view, scalar* data,
device_triplet_collection_types::view triplet_view) {
const device_triplet_collection_types::const_view triplet_view,
vecmem::data::vector_view<unsigned int> num_confirmations_view) {

// Check if anything needs to be done.
device_triplet_collection_types::device triplets(triplet_view);
const device_triplet_collection_types::const_device triplets(triplet_view);
if (globalIndex >= triplets.size()) {
return;
}

vecmem::device_vector<unsigned int> num_confirmations(
num_confirmations_view);

// Set up the device containers
const edm::spacepoint_collection::const_device spacepoints{
spacepoints_view};
Expand All @@ -38,7 +42,7 @@ inline void update_triplet_weights(
tc_view);

// Current work item
device_triplet this_triplet = triplets.at(globalIndex);
const device_triplet& this_triplet = triplets.at(globalIndex);

const edm::spacepoint_collection::const_device::const_proxy_type
current_spT = spacepoints.at(this_triplet.spT);
Expand All @@ -52,7 +56,7 @@ inline void update_triplet_weights(
this_triplet.curvature - filter_config.deltaInvHelixDiameter;
const scalar upperLimitCurv =
this_triplet.curvature + filter_config.deltaInvHelixDiameter;
std::size_t num_compat_seedR = 0;
unsigned int num_compat_seedR = 0;

const triplet_counter mb_count =
triplet_counts.at(static_cast<unsigned int>(this_triplet.counter_link));
Expand Down Expand Up @@ -116,7 +120,6 @@ inline void update_triplet_weights(

if (newCompSeed) {
data[num_compat_seedR] = otherTop_r;
this_triplet.weight += filter_config.compatSeedWeight;
num_compat_seedR++;
}

Expand All @@ -125,7 +128,7 @@ inline void update_triplet_weights(
}
}

triplets.at(globalIndex).weight = this_triplet.weight;
num_confirmations.at(globalIndex) = num_compat_seedR;
}

} // namespace traccc::device
34 changes: 27 additions & 7 deletions device/common/include/traccc/seeding/device/impl/select_seeds.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ inline void select_seeds(
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view,
const device_triplet_collection_types::const_view& triplet_view,
const vecmem::data::vector_view<const unsigned int> num_confirmations_view,
device_triplet* data, edm::seed_collection::view seed_view) {

// Check if anything needs to be done.
Expand All @@ -85,6 +86,8 @@ inline void select_seeds(
sp_view);

const device_triplet_collection_types::const_device triplets(triplet_view);
const vecmem::device_vector<const unsigned int> num_confirmations(
num_confirmations_view);
edm::seed_collection::device seeds_device(seed_view);

// Current work item = middle spacepoint
Expand All @@ -99,6 +102,14 @@ inline void select_seeds(

const unsigned int end_triplets_spM =
spM_counter.posTriplets + spM_counter.m_nTriplets;

unsigned int max_num_confirmations = 0;

for (unsigned int i = spM_counter.posTriplets; i < end_triplets_spM; ++i) {
max_num_confirmations =
std::max(max_num_confirmations, num_confirmations.at(i));
}

// iterate over the triplets in the bin
for (unsigned int i = spM_counter.posTriplets; i < end_triplets_spM; ++i) {
device_triplet aTriplet = triplets[i];
Expand All @@ -111,10 +122,19 @@ inline void select_seeds(
const edm::spacepoint_collection::const_device::const_proxy_type spT =
spacepoints.at(spT_idx);

if (num_confirmations.at(i) + 1 < max_num_confirmations) {
continue;
}

// update weight of triplet
seed_selecting_helper::seed_weight(filter_config, spM, spB, spT,
aTriplet.weight);

aTriplet.weight +=
static_cast<scalar>(std::min(num_confirmations.at(i),
filter_config.compatSeedLimit)) *
filter_config.compatSeedWeight;

// check if it is a good triplet
if (!seed_selecting_helper::single_seed_cut(filter_config, spM, spB,
spT, aTriplet.weight)) {
Expand Down Expand Up @@ -165,15 +185,15 @@ inline void select_seeds(
break;
}

// check if it is a good triplet
if (seed_selecting_helper::cut_per_middle_sp(
filter_config, spacepoints.at(aTriplet.spB), aTriplet.weight) ||
n_seeds_per_spM == 0) {
if (spacepoints.at(aTriplet.spB).radius() <=
filter_config.spB_min_radius &&
n_seeds_per_spM > 0) {
continue;
}

n_seeds_per_spM++;
n_seeds_per_spM++;

seeds_device.push_back({aTriplet.spB, aTriplet.spM, aTriplet.spT});
}
seeds_device.push_back({aTriplet.spB, aTriplet.spM, aTriplet.spT});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ inline void select_seeds(
const triplet_counter_spM_collection_types::const_view& spM_tc_view,
const triplet_counter_collection_types::const_view& tc_view,
const device_triplet_collection_types::const_view& triplet_view,
const vecmem::data::vector_view<const unsigned int> num_confirmations_view,
triplet* data, edm::seed_collection::view seed_view);

} // namespace traccc::device
Expand Down
27 changes: 17 additions & 10 deletions device/cuda/src/seeding/seed_finding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
#include "traccc/seeding/device/count_doublets.hpp"
#include "traccc/seeding/device/count_triplets.hpp"
#include "traccc/seeding/device/find_doublets.hpp"
#include "traccc/seeding/device/find_triplet_confirmations.hpp"
#include "traccc/seeding/device/find_triplets.hpp"
#include "traccc/seeding/device/reduce_triplet_counts.hpp"
#include "traccc/seeding/device/select_seeds.hpp"
#include "traccc/seeding/device/update_triplet_weights.hpp"

// VecMem include(s).
#include <vecmem/utils/cuda/copy.hpp>
Expand Down Expand Up @@ -108,22 +108,23 @@ __global__ void find_triplets(
}

/// CUDA kernel for running @c traccc::device::update_triplet_weights
__global__ void update_triplet_weights(
__global__ void find_triplet_confirmations(
seedfilter_config filter_config,
edm::spacepoint_collection::const_view spacepoints,
device::triplet_counter_spM_collection_types::const_view spM_tc,
device::triplet_counter_collection_types::const_view midBot_tc,
device::device_triplet_collection_types::view triplet_view) {
device::device_triplet_collection_types::view triplet_view,
vecmem::data::vector_view<unsigned int> num_confirmations_view) {

// Array for temporary storage of quality parameters for comparing triplets
// within weight updating kernel
extern __shared__ scalar data[];
// Each thread uses compatSeedLimit elements of the array
scalar* dataPos = &data[threadIdx.x * filter_config.compatSeedLimit];

device::update_triplet_weights(details::global_index1(), filter_config,
spacepoints, spM_tc, midBot_tc, dataPos,
triplet_view);
device::find_triplet_confirmations(details::global_index1(), filter_config,
spacepoints, spM_tc, midBot_tc, dataPos,
triplet_view, num_confirmations_view);
}

/// CUDA kernel for running @c traccc::device::select_seeds
Expand All @@ -134,6 +135,7 @@ __global__ void select_seeds(
device::triplet_counter_spM_collection_types::const_view spM_tc,
device::triplet_counter_collection_types::const_view midBot_tc,
device::device_triplet_collection_types::view triplet_view,
const vecmem::data::vector_view<const unsigned int> num_confirmations_view,
edm::seed_collection::view seed_view) {

// Array for temporary storage of triplets for comparing within seed
Expand All @@ -145,7 +147,7 @@ __global__ void select_seeds(

device::select_seeds(details::global_index1(), finder_config, filter_config,
spacepoints, sp_view, spM_tc, midBot_tc, triplet_view,
dataPos, seed_view);
num_confirmations_view, dataPos, seed_view);
}

} // namespace kernels
Expand Down Expand Up @@ -324,6 +326,11 @@ edm::seed_collection::buffer seed_finding::operator()(
triplet_buffer);
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());

vecmem::data::vector_buffer<unsigned int> num_confirmations_buffer(
globalCounter_host->m_nTriplets, m_mr.main);
m_copy.setup(num_confirmations_buffer)->wait();
m_copy.memset(num_confirmations_buffer, 0)->wait();

// Calculate the number of threads and thread blocks to run the weight
// updating kernel for.
const unsigned int nWeightUpdatingThreads = m_warp_size * 2;
Expand All @@ -332,13 +339,13 @@ edm::seed_collection::buffer seed_finding::operator()(
nWeightUpdatingThreads;

// Update the weights of all spacepoint triplets.
kernels::update_triplet_weights<<<
kernels::find_triplet_confirmations<<<
nWeightUpdatingBlocks, nWeightUpdatingThreads,
sizeof(scalar) * m_seedfilter_config.compatSeedLimit *
nWeightUpdatingThreads,
stream>>>(m_seedfilter_config, spacepoints_view,
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
triplet_buffer);
triplet_buffer, num_confirmations_buffer);
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());

// Create result object: collection of seeds
Expand All @@ -362,7 +369,7 @@ edm::seed_collection::buffer seed_finding::operator()(
stream>>>(
m_seedfinder_config, m_seedfilter_config, spacepoints_view, g2_view,
triplet_counter_spM_buffer, triplet_counter_midBot_buffer,
triplet_buffer, seed_buffer);
triplet_buffer, num_confirmations_buffer, seed_buffer);
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());

return seed_buffer;
Expand Down
Loading