Skip to content

Commit b3ce51f

Browse files
committed
Add Node and NodeList. Add KDTreeBase base class.
1 parent 415fbfd commit b3ce51f

File tree

1 file changed

+52
-12
lines changed

1 file changed

+52
-12
lines changed

src/atlas/util/detail/KDTree.h

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,20 @@ class ValueListAdaptor {
454454
std::vector<ValueTy> pts;
455455
};
456456

457+
template<typename ValueTy>
458+
class KDTree_nanoflann_node {
459+
public:
460+
KDTree_nanoflann_node(const ValueTy& value) : value_(value) {}
461+
462+
const ValueTy& value() const { return value_; }
463+
auto point() const { return value_.point(); }
464+
auto payload() const { return value_.payload(); }
465+
double distance() const { return value_.distance(); }
466+
467+
private:
468+
ValueTy value_;
469+
};
470+
457471
// Note - KDTree_nanoflann is currently designed as an alternative to KDTreeMemory,
458472
// but would probably be better as an alternative to KDTree_eckit. It can be used
459473
// in its current form as follows:
@@ -475,7 +489,8 @@ class KDTree_nanoflann {
475489
using ValueList = typename Interface::ValueList;
476490
using DataSetAdaptor = ValueListAdaptor<Value>;
477491
constexpr static int DIMS = Point::DIMS;
478-
using Node = Value;
492+
using Node = KDTree_nanoflann_node<Value>;
493+
using NodeList = std::vector<Node>;
479494

480495
using Tree = nanoflann::KDTreeSingleIndexAdaptor<
481496
nanoflann::L2_Simple_Adaptor<double, DataSetAdaptor>,
@@ -492,6 +507,13 @@ class KDTree_nanoflann {
492507

493508
KDTree_nanoflann(const std::shared_ptr<Tree>& tree): index_(tree) {}
494509

510+
using iterator = typename std::vector<Value>::iterator;
511+
using const_iterator = typename std::vector<Value>::const_iterator;
512+
iterator begin() { return dataset_.pts.begin(); }
513+
iterator end() { return dataset_.pts.end(); }
514+
const_iterator begin() const { return dataset_.pts.begin(); }
515+
const_iterator end() const { return dataset_.pts.end(); }
516+
495517
atlas::idx_t size() const {
496518
return static_cast<atlas::idx_t>(dataset_.pts.size());
497519
}
@@ -518,6 +540,17 @@ class KDTree_nanoflann {
518540
build();
519541
}
520542

543+
template<typename Iterator>
544+
void build(Iterator begin, Iterator end) {
545+
dataset_.pts.clear();
546+
dataset_.pts.reserve(std::distance(begin, end));
547+
for (auto it = begin; it != end; ++it) {
548+
const auto& value = *it;
549+
dataset_.pts.emplace_back(value.point(), value.payload());
550+
}
551+
build();
552+
}
553+
521554
/// @brief Insert 3D cartesian point (x,y,z)
522555
/// If memory has been reserved with reserve(), insertion will be delayed until build() is called.
523556
void insert(const Value& value) {
@@ -527,7 +560,7 @@ class KDTree_nanoflann {
527560
}
528561

529562
/// @brief Find k nearest neighbours given a 3D cartesian point (x,y,z)
530-
ValueList kNearestNeighbours(const Point& query_point, size_t k) const {
563+
NodeList kNearestNeighbours(const Point& query_point, size_t k) const {
531564
if (!index_) {
532565
throw_AssertionFailed("KDTree was used before calling build()");
533566
}
@@ -548,32 +581,32 @@ class KDTree_nanoflann {
548581
throw_AssertionFailed("KDTree::kNearestNeighbours: not enough neighbors found");
549582
}
550583

551-
std::vector<Value> results;
584+
std::vector<Node> results;
552585
results.reserve(k);
553586

554587
for (size_t i = 0; i < k; ++i) {
555588
const auto& value = dataset_.pts[ret_indexes[i]];
556-
results.emplace_back(value.point(), value.payload(), std::sqrt(out_dists_sqr[i]));
589+
results.emplace_back(Value{value.point(), value.payload(), std::sqrt(out_dists_sqr[i])});
557590
}
558-
559-
return ValueList{results};
591+
592+
return NodeList{results};
560593
}
561594

562595
/// @brief Find nearest neighbour given a 3D cartesian point (x,y,z)
563-
Value nearestNeighbour(const Point& query_point) const {
596+
Node nearestNeighbour(const Point& query_point) const {
564597
if (!index_) {
565598
throw_AssertionFailed("KDTree was used before calling build()");
566599
}
567600
auto results = kNearestNeighbours(query_point, 1);
568601
if (results.empty()) {
569602
// Return a default value if no points found
570-
return Value{query_point, Payload{}, std::numeric_limits<double>::max()};
603+
return Node{Value{query_point, Payload{}, std::numeric_limits<double>::max()}};
571604
}
572605
return results[0];
573606
}
574607

575608
/// @brief Find all points within a distance of given radius from a given point (x,y,z)
576-
ValueList findInSphere(const Point& query_point, double radius) const {
609+
NodeList findInSphere(const Point& query_point, double radius) const {
577610
if (!index_) {
578611
throw_AssertionFailed("KDTree was used before calling build()");
579612
}
@@ -586,18 +619,25 @@ class KDTree_nanoflann {
586619

587620
index_->findNeighbors(resultSet, query_pt);
588621

589-
std::vector<Value> results;
622+
std::vector<Node> results;
590623
results.reserve(indices_dists.size());
591624

592625
for (const auto& idx_dist : indices_dists) {
593626
const auto& value = dataset_.pts[idx_dist.first];
594-
results.emplace_back(value.point(), value.payload(), std::sqrt(idx_dist.second));
627+
results.emplace_back(Value{value.point(), value.payload(), std::sqrt(idx_dist.second)});
595628
}
596629

597-
return ValueList{results};
630+
return NodeList{results};
598631
}
599632
};
600633

634+
//------------------------------------------------------------------------------------------------------
635+
636+
template <typename Payload, typename Point>
637+
using KDTreeNanoflann = KDTree_eckit<KDTree_nanoflann<Payload, Point>>;
638+
639+
//------------------------------------------------------------------------------------------------------
640+
601641
} // namespace detail
602642
} // namespace util
603643
} // namespace atlas

0 commit comments

Comments
 (0)