Skip to content

Commit 415fbfd

Browse files
committed
Add add a wrapper on nanoflann that is compatible with Atlas KDTree.
1 parent 49d5f36 commit 415fbfd

File tree

3 files changed

+3146
-0
lines changed

3 files changed

+3146
-0
lines changed

src/atlas/util/detail/KDTree.h

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <iosfwd>
1414
#include <memory>
1515

16+
#include "atlas/util/detail/nanoflann.hpp"
17+
1618
#include "eckit/container/KDTree.h"
1719

1820
#include "atlas/library/config.h"
@@ -432,6 +434,170 @@ void KDTree_eckit<TreeT, PayloadT, PointT>::assert_built() const {
432434

433435
//------------------------------------------------------------------------------------------------------
434436

437+
template<typename ValueTy>
438+
class ValueListAdaptor {
439+
public:
440+
ValueListAdaptor() = default;
441+
explicit ValueListAdaptor(const std::vector<ValueTy>& values) : pts(values) {}
442+
443+
// --- Nanoflann interface ---
444+
inline size_t kdtree_get_point_count() const { return pts.size(); }
445+
446+
inline double kdtree_get_pt(const size_t idx, const size_t dim) const {
447+
return pts[idx].point()[dim];
448+
}
449+
450+
template <class BBOX>
451+
bool kdtree_get_bbox(BBOX&) const { return false; }
452+
// ----------------------------
453+
454+
std::vector<ValueTy> pts;
455+
};
456+
457+
// Note - KDTree_nanoflann is currently designed as an alternative to KDTreeMemory,
458+
// but would probably be better as an alternative to KDTree_eckit. It can be used
459+
// in its current form as follows:
460+
// ```
461+
// auto nanoflann_kdtree_ = std::make_shared<
462+
// util::detail::KDTree_nanoflann<atlas::idx_t, eckit::geometry::Point3>>();
463+
//
464+
// atlas::util::IndexKDTree atlas_kdtree_wrapper_{nanoflann_kdtree_, atlas::util::Geometry{}};
465+
// ```
466+
// Note - this implementation is working but should be considered a draft.
467+
// Note - nanoflann is thread-safe for queries (https://github.com/jlblancoc/nanoflann/issues/54)
468+
template <typename PayloadTy, typename PointTy = eckit::geometry::Point3>
469+
class KDTree_nanoflann {
470+
public:
471+
using Interface = atlas::util::detail::KDTreeBase<PayloadTy, PointTy>;
472+
using Payload = PayloadTy;
473+
using Point = PointTy;
474+
using Value = typename Interface::Value;
475+
using ValueList = typename Interface::ValueList;
476+
using DataSetAdaptor = ValueListAdaptor<Value>;
477+
constexpr static int DIMS = Point::DIMS;
478+
using Node = Value;
479+
480+
using Tree = nanoflann::KDTreeSingleIndexAdaptor<
481+
nanoflann::L2_Simple_Adaptor<double, DataSetAdaptor>,
482+
DataSetAdaptor,
483+
DIMS
484+
>;
485+
486+
private:
487+
DataSetAdaptor dataset_;
488+
std::shared_ptr<Tree> index_;
489+
490+
public:
491+
KDTree_nanoflann() {}
492+
493+
KDTree_nanoflann(const std::shared_ptr<Tree>& tree): index_(tree) {}
494+
495+
atlas::idx_t size() const {
496+
return static_cast<atlas::idx_t>(dataset_.pts.size());
497+
}
498+
499+
void build() {
500+
if (!dataset_.pts.empty()) {
501+
index_ = std::make_shared<Tree>(
502+
DIMS, dataset_, nanoflann::KDTreeSingleIndexAdaptorParams(10 /* max leaf */)
503+
);
504+
index_->buildIndex();
505+
} else {
506+
index_.reset();
507+
}
508+
}
509+
510+
void build(std::vector<Value>& values) {
511+
dataset_.pts.clear();
512+
dataset_.pts.reserve(values.size());
513+
514+
for (const auto& value : values) {
515+
dataset_.pts.emplace_back(value.point(), value.payload());
516+
}
517+
518+
build();
519+
}
520+
521+
/// @brief Insert 3D cartesian point (x,y,z)
522+
/// If memory has been reserved with reserve(), insertion will be delayed until build() is called.
523+
void insert(const Value& value) {
524+
// Insert immediately and rebuild index
525+
dataset_.pts.emplace_back(value.point(), value.payload());
526+
build();
527+
}
528+
529+
/// @brief Find k nearest neighbours given a 3D cartesian point (x,y,z)
530+
ValueList kNearestNeighbours(const Point& query_point, size_t k) const {
531+
if (!index_) {
532+
throw_AssertionFailed("KDTree was used before calling build()");
533+
}
534+
535+
double query_pt[DIMS];
536+
for (int i = 0; i < DIMS; ++i) {
537+
query_pt[i] = query_point[i];
538+
}
539+
540+
std::vector<size_t> ret_indexes(k);
541+
std::vector<double> out_dists_sqr(k);
542+
543+
nanoflann::KNNResultSet<double> resultSet(k);
544+
resultSet.init(ret_indexes.data(), out_dists_sqr.data());
545+
546+
bool found_k_neighbors = index_->findNeighbors(resultSet, query_pt);
547+
if (size() >= k && !found_k_neighbors) {
548+
throw_AssertionFailed("KDTree::kNearestNeighbours: not enough neighbors found");
549+
}
550+
551+
std::vector<Value> results;
552+
results.reserve(k);
553+
554+
for (size_t i = 0; i < k; ++i) {
555+
const auto& value = dataset_.pts[ret_indexes[i]];
556+
results.emplace_back(value.point(), value.payload(), std::sqrt(out_dists_sqr[i]));
557+
}
558+
559+
return ValueList{results};
560+
}
561+
562+
/// @brief Find nearest neighbour given a 3D cartesian point (x,y,z)
563+
Value nearestNeighbour(const Point& query_point) const {
564+
if (!index_) {
565+
throw_AssertionFailed("KDTree was used before calling build()");
566+
}
567+
auto results = kNearestNeighbours(query_point, 1);
568+
if (results.empty()) {
569+
// Return a default value if no points found
570+
return Value{query_point, Payload{}, std::numeric_limits<double>::max()};
571+
}
572+
return results[0];
573+
}
574+
575+
/// @brief Find all points within a distance of given radius from a given point (x,y,z)
576+
ValueList findInSphere(const Point& query_point, double radius) const {
577+
if (!index_) {
578+
throw_AssertionFailed("KDTree was used before calling build()");
579+
}
580+
581+
double query_pt[3] = {query_point[0], query_point[1], query_point[2]};
582+
double radius_sqr = radius * radius;
583+
584+
std::vector<nanoflann::ResultItem<size_t, double>> indices_dists;
585+
nanoflann::RadiusResultSet<double, size_t> resultSet(radius_sqr, indices_dists);
586+
587+
index_->findNeighbors(resultSet, query_pt);
588+
589+
std::vector<Value> results;
590+
results.reserve(indices_dists.size());
591+
592+
for (const auto& idx_dist : indices_dists) {
593+
const auto& value = dataset_.pts[idx_dist.first];
594+
results.emplace_back(value.point(), value.payload(), std::sqrt(idx_dist.second));
595+
}
596+
597+
return ValueList{results};
598+
}
599+
};
600+
435601
} // namespace detail
436602
} // namespace util
437603
} // namespace atlas

0 commit comments

Comments
 (0)