@@ -454,6 +454,20 @@ class ValueListAdaptor {
454454 std::vector<ValueTy> pts;
455455};
456456
457+ template <typename ValueTy>
458+ class KDTree_nanoflann_node {
459+ public:
460+ KDTree_nanoflann_node (const ValueTy& value) : value_(value) {}
461+
462+ const ValueTy& value () const { return value_; }
463+ auto point () const { return value_.point (); }
464+ auto payload () const { return value_.payload (); }
465+ double distance () const { return value_.distance (); }
466+
467+ private:
468+ ValueTy value_;
469+ };
470+
457471// Note - KDTree_nanoflann is currently designed as an alternative to KDTreeMemory,
458472// but would probably be better as an alternative to KDTree_eckit. It can be used
459473// in its current form as follows:
@@ -475,7 +489,8 @@ class KDTree_nanoflann {
475489 using ValueList = typename Interface::ValueList;
476490 using DataSetAdaptor = ValueListAdaptor<Value>;
477491 constexpr static int DIMS = Point::DIMS;
478- using Node = Value;
492+ using Node = KDTree_nanoflann_node<Value>;
493+ using NodeList = std::vector<Node>;
479494
480495 using Tree = nanoflann::KDTreeSingleIndexAdaptor<
481496 nanoflann::L2_Simple_Adaptor<double , DataSetAdaptor>,
@@ -492,6 +507,13 @@ class KDTree_nanoflann {
492507
493508 KDTree_nanoflann (const std::shared_ptr<Tree>& tree): index_(tree) {}
494509
510+ using iterator = typename std::vector<Value>::iterator;
511+ using const_iterator = typename std::vector<Value>::const_iterator;
512+ iterator begin () { return dataset_.pts .begin (); }
513+ iterator end () { return dataset_.pts .end (); }
514+ const_iterator begin () const { return dataset_.pts .begin (); }
515+ const_iterator end () const { return dataset_.pts .end (); }
516+
495517 atlas::idx_t size () const {
496518 return static_cast <atlas::idx_t >(dataset_.pts .size ());
497519 }
@@ -518,6 +540,17 @@ class KDTree_nanoflann {
518540 build ();
519541 }
520542
543+ template <typename Iterator>
544+ void build (Iterator begin, Iterator end) {
545+ dataset_.pts .clear ();
546+ dataset_.pts .reserve (std::distance (begin, end));
547+ for (auto it = begin; it != end; ++it) {
548+ const auto & value = *it;
549+ dataset_.pts .emplace_back (value.point (), value.payload ());
550+ }
551+ build ();
552+ }
553+
521554 // / @brief Insert 3D cartesian point (x,y,z)
522555 // / If memory has been reserved with reserve(), insertion will be delayed until build() is called.
523556 void insert (const Value& value) {
@@ -527,7 +560,7 @@ class KDTree_nanoflann {
527560 }
528561
529562 // / @brief Find k nearest neighbours given a 3D cartesian point (x,y,z)
530- ValueList kNearestNeighbours (const Point& query_point, size_t k) const {
563+ NodeList kNearestNeighbours (const Point& query_point, size_t k) const {
531564 if (!index_) {
532565 throw_AssertionFailed (" KDTree was used before calling build()" );
533566 }
@@ -548,32 +581,32 @@ class KDTree_nanoflann {
548581 throw_AssertionFailed (" KDTree::kNearestNeighbours: not enough neighbors found" );
549582 }
550583
551- std::vector<Value > results;
584+ std::vector<Node > results;
552585 results.reserve (k);
553586
554587 for (size_t i = 0 ; i < k; ++i) {
555588 const auto & value = dataset_.pts [ret_indexes[i]];
556- results.emplace_back (value.point (), value.payload (), std::sqrt (out_dists_sqr[i]));
589+ results.emplace_back (Value{ value.point (), value.payload (), std::sqrt (out_dists_sqr[i])} );
557590 }
558-
559- return ValueList {results};
591+
592+ return NodeList {results};
560593 }
561594
562595 // / @brief Find nearest neighbour given a 3D cartesian point (x,y,z)
563- Value nearestNeighbour (const Point& query_point) const {
596+ Node nearestNeighbour (const Point& query_point) const {
564597 if (!index_) {
565598 throw_AssertionFailed (" KDTree was used before calling build()" );
566599 }
567600 auto results = kNearestNeighbours (query_point, 1 );
568601 if (results.empty ()) {
569602 // Return a default value if no points found
570- return Value{query_point, Payload{}, std::numeric_limits<double >::max ()};
603+ return Node{ Value{query_point, Payload{}, std::numeric_limits<double >::max ()} };
571604 }
572605 return results[0 ];
573606 }
574607
575608 // / @brief Find all points within a distance of given radius from a given point (x,y,z)
576- ValueList findInSphere (const Point& query_point, double radius) const {
609+ NodeList findInSphere (const Point& query_point, double radius) const {
577610 if (!index_) {
578611 throw_AssertionFailed (" KDTree was used before calling build()" );
579612 }
@@ -586,18 +619,25 @@ class KDTree_nanoflann {
586619
587620 index_->findNeighbors (resultSet, query_pt);
588621
589- std::vector<Value > results;
622+ std::vector<Node > results;
590623 results.reserve (indices_dists.size ());
591624
592625 for (const auto & idx_dist : indices_dists) {
593626 const auto & value = dataset_.pts [idx_dist.first ];
594- results.emplace_back (value.point (), value.payload (), std::sqrt (idx_dist.second ));
627+ results.emplace_back (Value{ value.point (), value.payload (), std::sqrt (idx_dist.second )} );
595628 }
596629
597- return ValueList {results};
630+ return NodeList {results};
598631 }
599632};
600633
634+ // ------------------------------------------------------------------------------------------------------
635+
636+ template <typename Payload, typename Point>
637+ using KDTreeNanoflann = KDTree_eckit<KDTree_nanoflann<Payload, Point>>;
638+
639+ // ------------------------------------------------------------------------------------------------------
640+
601641} // namespace detail
602642} // namespace util
603643} // namespace atlas
0 commit comments