1+ // SPDX-License-Identifier: MIT
2+ // Copyright (c) 2025 Kenji Koide (k.koide@aist.go.jp)
3+ #include < gtsam_points/ann/kdtree_cuda.hpp>
4+
5+ #include < gtsam_points/cuda/check_error.cuh>
6+ #include < gtsam_points/ann/small_kdtree.hpp>
7+
8+ namespace gtsam_points {
9+
// Builds a KdTree over the input cloud on the CPU and uploads the flattened
// index/node arrays to the GPU on the given stream.
//
// @param points  Input point cloud; must have both CPU points (for tree
//                construction) and GPU points (for device-side search).
// @param stream  CUDA stream used for the device allocations and uploads.
//
// On invalid input the members stay zero/null and the error is reported on
// stderr; the destructor tolerates that state.
KdTreeGPU::KdTreeGPU(const PointCloud::ConstPtr& points, CUstream_st* stream)
: points(points),
  num_indices(0),
  num_nodes(0),
  indices(nullptr),
  nodes(nullptr) {
  // Validate input
  if (!points->has_points()) {
    std::cerr << "error: empty point cloud is given for KdTreeGPU" << std::endl;
    return;
  }
  if (!points->has_points_gpu()) {
    std::cerr << "error: point cloud does not have GPU points for KdTreeGPU" << std::endl;
    return;
  }

  // Build the tree on the CPU
  KdTreeBuilder builder;
  UnsafeKdTree<PointCloud> kdtree(*points, builder);

  // Flatten the CPU tree into POD staging buffers for the GPU.
  std::vector<std::uint32_t> h_indices(kdtree.indices.begin(), kdtree.indices.end());
  std::vector<KdTreeNodeGPU> h_nodes(kdtree.nodes.size());

  for (size_t i = 0; i < kdtree.nodes.size(); i++) {
    const auto& in = kdtree.nodes[i];
    auto& out = h_nodes[i];

    out.left = in.left;
    out.right = in.right;

    if (in.left == INVALID_NODE) {
      // Leaf: store the [first, last) range into the index array.
      out.node_type.lr.first = in.node_type.lr.first;
      out.node_type.lr.last = in.node_type.lr.last;
    } else {
      // Internal: store the split axis and threshold.
      out.node_type.sub.axis = in.node_type.sub.proj.axis;
      out.node_type.sub.thresh = in.node_type.sub.thresh;
    }
  }

  num_indices = kdtree.indices.size();
  num_nodes = kdtree.nodes.size();
  check_error << cudaMallocAsync(&indices, sizeof(std::uint32_t) * num_indices, stream);
  check_error << cudaMallocAsync(&nodes, sizeof(KdTreeNodeGPU) * num_nodes, stream);
  check_error << cudaMemcpyAsync(indices, h_indices.data(), sizeof(std::uint32_t) * num_indices, cudaMemcpyHostToDevice, stream);
  check_error << cudaMemcpyAsync(nodes, h_nodes.data(), sizeof(KdTreeNodeGPU) * num_nodes, cudaMemcpyHostToDevice, stream);

  // h_indices / h_nodes are pageable host buffers that die when this
  // constructor returns; the async uploads above may still be reading them.
  // Synchronize so the copies complete before the staging buffers are freed.
  check_error << cudaStreamSynchronize(stream);
}
57+
// Releases the device-side index/node buffers.
//
// The null checks keep the destructor safe when the constructor bailed out
// early and left the pointers null — cudaFreeAsync on a null pointer is an
// error on CUDA runtimes before 12.x, and check_error must not fire in a
// destructor. The frees are enqueued on the null stream so they are ordered
// after previously submitted work.
KdTreeGPU::~KdTreeGPU() {
  if (indices) {
    check_error << cudaFreeAsync(indices, nullptr);
  }
  if (nodes) {
    check_error << cudaFreeAsync(nodes, nullptr);
  }
}
62+
// Convenience wrapper: runs the GPU nearest-neighbor search for host-resident
// queries and returns the results in host buffers.
//
// @param h_queries     Host array of query points (num_queries elements).
// @param num_queries   Number of queries; 0 is a no-op.
// @param h_nn_indices  [out] Host array receiving the NN index per query.
// @param h_nn_sq_dists [out] Host array receiving the squared distance per query.
// @param stream        CUDA stream for all transfers and the search kernel.
//
// Blocks until the results are available in h_nn_indices / h_nn_sq_dists.
void KdTreeGPU::nearest_neighbor_search_cpu(
  const Eigen::Vector3f* h_queries,
  size_t num_queries,
  std::uint32_t* h_nn_indices,
  float* h_nn_sq_dists,
  CUstream_st* stream) {
  // Nothing to do (and avoids zero-byte allocations/copies).
  if (num_queries == 0) {
    return;
  }

  // Stage queries and result buffers on the device.
  Eigen::Vector3f* d_queries;
  std::uint32_t* d_nn_indices;
  float* d_nn_sq_dists;

  check_error << cudaMallocAsync(&d_queries, sizeof(Eigen::Vector3f) * num_queries, stream);
  check_error << cudaMallocAsync(&d_nn_indices, sizeof(std::uint32_t) * num_queries, stream);
  check_error << cudaMallocAsync(&d_nn_sq_dists, sizeof(float) * num_queries, stream);
  check_error << cudaMemcpyAsync(d_queries, h_queries, sizeof(Eigen::Vector3f) * num_queries, cudaMemcpyHostToDevice, stream);

  nearest_neighbor_search(d_queries, num_queries, d_nn_indices, d_nn_sq_dists, stream);

  check_error << cudaMemcpyAsync(h_nn_indices, d_nn_indices, sizeof(std::uint32_t) * num_queries, cudaMemcpyDeviceToHost, stream);
  check_error << cudaMemcpyAsync(h_nn_sq_dists, d_nn_sq_dists, sizeof(float) * num_queries, cudaMemcpyDeviceToHost, stream);

  check_error << cudaFreeAsync(d_queries, stream);
  check_error << cudaFreeAsync(d_nn_indices, stream);
  check_error << cudaFreeAsync(d_nn_sq_dists, stream);

  // The D2H copies above are asynchronous: without this synchronization the
  // caller could read h_nn_indices / h_nn_sq_dists before they are written
  // (and h_queries, a pageable host buffer, could be freed while the H2D
  // copy is still in flight).
  check_error << cudaStreamSynchronize(stream);
}
88+
89+ } // namespace gtsam_points