@@ -261,14 +261,43 @@ combinatorial_kalman_filter(
261261 {
262262 vecmem::data::vector_buffer<candidate_link> tmp_links_buffer (
263263 n_max_candidates, mr.main );
264- copy.setup (tmp_links_buffer)->ignore ();
264+ copy.setup (tmp_links_buffer)->wait ();
265265 bound_track_parameters_collection_types::buffer tmp_params_buffer (
266266 n_max_candidates, mr.main );
267- copy.setup (tmp_params_buffer)->ignore ();
267+ copy.setup (tmp_params_buffer)->wait ();
268268
269269 // The number of threads to use per block in the track finding.
270270 static const unsigned int nFindTracksThreads = 64 ;
271271
272+ // Allocate the kernel's payload in host memory.
273+ using payload_t = device::find_tracks_payload<detector_t >;
274+ const payload_t host_payload{
275+ .det_data = det,
276+ .measurements_view = measurements,
277+ .in_params_view = in_params_buffer,
278+ .in_params_liveness_view = param_liveness_buffer,
279+ .n_in_params = n_in_params,
280+ .barcodes_view = barcodes_buffer,
281+ .upper_bounds_view = upper_bounds_buffer,
282+ .links_view = links_buffer,
283+ .prev_links_idx =
284+ (step == 0 ? 0 : step_to_link_idx_map[step - 1 ]),
285+ .curr_links_idx = step_to_link_idx_map[step],
286+ .step = step,
287+ .out_params_view = updated_params_buffer,
288+ .out_params_liveness_view = updated_liveness_buffer,
289+ .tips_view = tips_buffer,
290+ .tip_lengths_view = tip_length_buffer,
291+ .n_tracks_per_seed_view = n_tracks_per_seed_buffer,
292+ .tmp_params_view = tmp_params_buffer,
293+ .tmp_links_view = tmp_links_buffer};
294+ // Now copy it to device memory.
295+ vecmem::data::vector_buffer<payload_t > device_payload (1u , mr.main );
296+ copy.setup (device_payload)->wait ();
297+ copy (vecmem::data::vector_view<const payload_t >(1u , &host_payload),
298+ device_payload)
299+ ->wait ();
300+
272301 // Submit the kernel to the queue.
273302 queue
274303 .submit ([&](::sycl::handler& h) {
@@ -287,27 +316,7 @@ combinatorial_kalman_filter(
287316 // Launch the kernel.
288317 h.parallel_for <kernels::find_tracks<kernel_t >>(
289318 calculate1DimNdRange (n_in_params, nFindTracksThreads),
290- [config, det, measurements,
291- in_params = vecmem::get_data (in_params_buffer),
292- param_liveness =
293- vecmem::get_data (param_liveness_buffer),
294- n_in_params,
295- barcodes = vecmem::get_data (barcodes_buffer),
296- upper_bounds = vecmem::get_data (upper_bounds_buffer),
297- links_view = vecmem::get_data (links_buffer),
298- prev_links_idx =
299- step == 0 ? 0 : step_to_link_idx_map[step - 1 ],
300- curr_links_idx = step_to_link_idx_map[step], step,
301- updated_params =
302- vecmem::get_data (updated_params_buffer),
303- updated_liveness =
304- vecmem::get_data (updated_liveness_buffer),
305- tips = vecmem::get_data (tips_buffer),
306- tip_lengths = vecmem::get_data (tip_length_buffer),
307- n_tracks_per_seed =
308- vecmem::get_data (n_tracks_per_seed_buffer),
309- tmp_params = vecmem::get_data (tmp_params_buffer),
310- tmp_links = vecmem::get_data (tmp_links_buffer),
319+ [config, payload = device_payload.ptr (),
311320 shared_insertion_mutex, shared_candidates,
312321 shared_candidates_size, shared_num_out_params,
313322 shared_out_offset](::sycl::nd_item<1 > item) {
@@ -317,13 +326,7 @@ combinatorial_kalman_filter(
317326
318327 // Call the device function to find tracks.
319328 device::find_tracks<detector_t >(
320- thread_id, barrier, config,
321- {det, measurements, in_params, param_liveness,
322- n_in_params, barcodes, upper_bounds,
323- links_view, prev_links_idx, curr_links_idx,
324- step, updated_params, updated_liveness, tips,
325- tip_lengths, n_tracks_per_seed, tmp_params,
326- tmp_links},
329+ thread_id, barrier, config, *payload,
327330 {shared_num_out_params[0 ], shared_out_offset[0 ],
328331 &(shared_insertion_mutex[0 ]),
329332 &(shared_candidates[0 ]),
@@ -478,34 +481,50 @@ combinatorial_kalman_filter(
478481 * Kernel5: Propagate to the next surface
479482 *****************************************************************/
480483
481- // Launch the kernel to propagate all active tracks to the next
482- // surface.
483- queue
484- .submit ([&](::sycl::handler& h) {
485- h.parallel_for <
486- kernels::propagate_to_next_surface<kernel_t >>(
487- calculate1DimNdRange (n_candidates, 64 ),
488- [config, det, field,
489- in_params = vecmem::get_data (in_params_buffer),
490- param_liveness =
491- vecmem::get_data (param_liveness_buffer),
492- param_ids = vecmem::get_data (param_ids_buffer),
493- links_view = vecmem::get_data (links_buffer),
494- prev_links_idx = step_to_link_idx_map[step], step,
495- n_candidates, tips = vecmem::get_data (tips_buffer),
496- tip_lengths = vecmem::get_data (tip_length_buffer)](
497- ::sycl::nd_item<1 > item) {
498- device::propagate_to_next_surface<
499- traccc::details::ckf_propagator_t <detector_t ,
500- bfield_t >,
501- bfield_t >(
502- details::global_index (item), config,
503- {det, field, in_params, param_liveness,
504- param_ids, links_view, prev_links_idx, step,
505- n_candidates, tips, tip_lengths});
506- });
507- })
508- .wait_and_throw ();
484+ {
485+ // Allocate the kernel's payload in host memory.
486+ using payload_t = device::propagate_to_next_surface_payload<
487+ traccc::details::ckf_propagator_t <detector_t , bfield_t >,
488+ bfield_t >;
489+ const payload_t host_payload{
490+ .det_data = det,
491+ .field_data = field,
492+ .params_view = in_params_buffer,
493+ .params_liveness_view = param_liveness_buffer,
494+ .param_ids_view = param_ids_buffer,
495+ .links_view = links_buffer,
496+ .prev_links_idx = step_to_link_idx_map[step],
497+ .step = step,
498+ .n_in_params = n_candidates,
499+ .tips_view = tips_buffer,
500+ .tip_lengths_view = tip_length_buffer};
501+ // Now copy it to device memory.
502+ vecmem::data::vector_buffer<payload_t > device_payload (1u ,
503+ mr.main );
504+ copy.setup (device_payload)->wait ();
505+ copy (vecmem::data::vector_view<const payload_t >(1u ,
506+ &host_payload),
507+ device_payload)
508+ ->wait ();
509+
510+ // Launch the kernel to propagate all active tracks to the next
511+ // surface.
512+ queue
513+ .submit ([&](::sycl::handler& h) {
514+ h.parallel_for <
515+ kernels::propagate_to_next_surface<kernel_t >>(
516+ calculate1DimNdRange (n_candidates, 64 ),
517+ [config, payload = device_payload.ptr ()](
518+ ::sycl::nd_item<1 > item) {
519+ device::propagate_to_next_surface<
520+ traccc::details::ckf_propagator_t <
521+ detector_t , bfield_t >,
522+ bfield_t >(details::global_index (item),
523+ config, *payload);
524+ });
525+ })
526+ .wait_and_throw ();
527+ }
509528 }
510529
511530 n_in_params = n_candidates;
0 commit comments