|
13 | 13 | #include <iosfwd> |
14 | 14 | #include <memory> |
15 | 15 |
|
| 16 | +#include "atlas/util/detail/nanoflann.hpp" |
| 17 | + |
16 | 18 | #include "eckit/container/KDTree.h" |
17 | 19 |
|
18 | 20 | #include "atlas/library/config.h" |
@@ -432,6 +434,170 @@ void KDTree_eckit<TreeT, PayloadT, PointT>::assert_built() const { |
432 | 434 |
|
433 | 435 | //------------------------------------------------------------------------------------------------------ |
434 | 436 |
|
| 437 | +template<typename ValueTy> |
| 438 | +class ValueListAdaptor { |
| 439 | +public: |
| 440 | + ValueListAdaptor() = default; |
| 441 | + explicit ValueListAdaptor(const std::vector<ValueTy>& values) : pts(values) {} |
| 442 | + |
| 443 | + // --- Nanoflann interface --- |
| 444 | + inline size_t kdtree_get_point_count() const { return pts.size(); } |
| 445 | + |
| 446 | + inline double kdtree_get_pt(const size_t idx, const size_t dim) const { |
| 447 | + return pts[idx].point()[dim]; |
| 448 | + } |
| 449 | + |
| 450 | + template <class BBOX> |
| 451 | + bool kdtree_get_bbox(BBOX&) const { return false; } |
| 452 | + // ---------------------------- |
| 453 | + |
| 454 | + std::vector<ValueTy> pts; |
| 455 | +}; |
| 456 | + |
| 457 | +// Note - KDTree_nanoflann is currently designed as an alternative to KDTreeMemory, |
| 458 | +// but would probably be better as an alternative to KDTree_eckit. It can be used |
| 459 | +// in its current form as follows: |
| 460 | +// ``` |
| 461 | +// auto nanoflann_kdtree_ = std::make_shared< |
| 462 | +// util::detail::KDTree_nanoflann<atlas::idx_t, eckit::geometry::Point3>>(); |
| 463 | +// |
| 464 | +// atlas::util::IndexKDTree atlas_kdtree_wrapper_{nanoflann_kdtree_, atlas::util::Geometry{}}; |
| 465 | +// ``` |
| 466 | +// Note - this implementation is working but should be considered a draft. |
| 467 | +// Note - nanoflann is thread-safe for queries (https://github.com/jlblancoc/nanoflann/issues/54) |
| 468 | +template <typename PayloadTy, typename PointTy = eckit::geometry::Point3> |
| 469 | +class KDTree_nanoflann { |
| 470 | +public: |
| 471 | + using Interface = atlas::util::detail::KDTreeBase<PayloadTy, PointTy>; |
| 472 | + using Payload = PayloadTy; |
| 473 | + using Point = PointTy; |
| 474 | + using Value = typename Interface::Value; |
| 475 | + using ValueList = typename Interface::ValueList; |
| 476 | + using DataSetAdaptor = ValueListAdaptor<Value>; |
| 477 | + constexpr static int DIMS = Point::DIMS; |
| 478 | + using Node = Value; |
| 479 | + |
| 480 | + using Tree = nanoflann::KDTreeSingleIndexAdaptor< |
| 481 | + nanoflann::L2_Simple_Adaptor<double, DataSetAdaptor>, |
| 482 | + DataSetAdaptor, |
| 483 | + DIMS |
| 484 | + >; |
| 485 | + |
| 486 | +private: |
| 487 | + DataSetAdaptor dataset_; |
| 488 | + std::shared_ptr<Tree> index_; |
| 489 | + |
| 490 | +public: |
| 491 | + KDTree_nanoflann() {} |
| 492 | + |
| 493 | + KDTree_nanoflann(const std::shared_ptr<Tree>& tree): index_(tree) {} |
| 494 | + |
| 495 | + atlas::idx_t size() const { |
| 496 | + return static_cast<atlas::idx_t>(dataset_.pts.size()); |
| 497 | + } |
| 498 | + |
| 499 | + void build() { |
| 500 | + if (!dataset_.pts.empty()) { |
| 501 | + index_ = std::make_shared<Tree>( |
| 502 | + DIMS, dataset_, nanoflann::KDTreeSingleIndexAdaptorParams(10 /* max leaf */) |
| 503 | + ); |
| 504 | + index_->buildIndex(); |
| 505 | + } else { |
| 506 | + index_.reset(); |
| 507 | + } |
| 508 | + } |
| 509 | + |
| 510 | + void build(std::vector<Value>& values) { |
| 511 | + dataset_.pts.clear(); |
| 512 | + dataset_.pts.reserve(values.size()); |
| 513 | + |
| 514 | + for (const auto& value : values) { |
| 515 | + dataset_.pts.emplace_back(value.point(), value.payload()); |
| 516 | + } |
| 517 | + |
| 518 | + build(); |
| 519 | + } |
| 520 | + |
| 521 | + /// @brief Insert 3D cartesian point (x,y,z) |
| 522 | + /// If memory has been reserved with reserve(), insertion will be delayed until build() is called. |
| 523 | + void insert(const Value& value) { |
| 524 | + // Insert immediately and rebuild index |
| 525 | + dataset_.pts.emplace_back(value.point(), value.payload()); |
| 526 | + build(); |
| 527 | + } |
| 528 | + |
| 529 | + /// @brief Find k nearest neighbours given a 3D cartesian point (x,y,z) |
| 530 | + ValueList kNearestNeighbours(const Point& query_point, size_t k) const { |
| 531 | + if (!index_) { |
| 532 | + throw_AssertionFailed("KDTree was used before calling build()"); |
| 533 | + } |
| 534 | + |
| 535 | + double query_pt[DIMS]; |
| 536 | + for (int i = 0; i < DIMS; ++i) { |
| 537 | + query_pt[i] = query_point[i]; |
| 538 | + } |
| 539 | + |
| 540 | + std::vector<size_t> ret_indexes(k); |
| 541 | + std::vector<double> out_dists_sqr(k); |
| 542 | + |
| 543 | + nanoflann::KNNResultSet<double> resultSet(k); |
| 544 | + resultSet.init(ret_indexes.data(), out_dists_sqr.data()); |
| 545 | + |
| 546 | + bool found_k_neighbors = index_->findNeighbors(resultSet, query_pt); |
| 547 | + if (size() >= k && !found_k_neighbors) { |
| 548 | + throw_AssertionFailed("KDTree::kNearestNeighbours: not enough neighbors found"); |
| 549 | + } |
| 550 | + |
| 551 | + std::vector<Value> results; |
| 552 | + results.reserve(k); |
| 553 | + |
| 554 | + for (size_t i = 0; i < k; ++i) { |
| 555 | + const auto& value = dataset_.pts[ret_indexes[i]]; |
| 556 | + results.emplace_back(value.point(), value.payload(), std::sqrt(out_dists_sqr[i])); |
| 557 | + } |
| 558 | + |
| 559 | + return ValueList{results}; |
| 560 | + } |
| 561 | + |
| 562 | + /// @brief Find nearest neighbour given a 3D cartesian point (x,y,z) |
| 563 | + Value nearestNeighbour(const Point& query_point) const { |
| 564 | + if (!index_) { |
| 565 | + throw_AssertionFailed("KDTree was used before calling build()"); |
| 566 | + } |
| 567 | + auto results = kNearestNeighbours(query_point, 1); |
| 568 | + if (results.empty()) { |
| 569 | + // Return a default value if no points found |
| 570 | + return Value{query_point, Payload{}, std::numeric_limits<double>::max()}; |
| 571 | + } |
| 572 | + return results[0]; |
| 573 | + } |
| 574 | + |
| 575 | + /// @brief Find all points within a distance of given radius from a given point (x,y,z) |
| 576 | + ValueList findInSphere(const Point& query_point, double radius) const { |
| 577 | + if (!index_) { |
| 578 | + throw_AssertionFailed("KDTree was used before calling build()"); |
| 579 | + } |
| 580 | + |
| 581 | + double query_pt[3] = {query_point[0], query_point[1], query_point[2]}; |
| 582 | + double radius_sqr = radius * radius; |
| 583 | + |
| 584 | + std::vector<nanoflann::ResultItem<size_t, double>> indices_dists; |
| 585 | + nanoflann::RadiusResultSet<double, size_t> resultSet(radius_sqr, indices_dists); |
| 586 | + |
| 587 | + index_->findNeighbors(resultSet, query_pt); |
| 588 | + |
| 589 | + std::vector<Value> results; |
| 590 | + results.reserve(indices_dists.size()); |
| 591 | + |
| 592 | + for (const auto& idx_dist : indices_dists) { |
| 593 | + const auto& value = dataset_.pts[idx_dist.first]; |
| 594 | + results.emplace_back(value.point(), value.payload(), std::sqrt(idx_dist.second)); |
| 595 | + } |
| 596 | + |
| 597 | + return ValueList{results}; |
| 598 | + } |
| 599 | +}; |
| 600 | + |
435 | 601 | } // namespace detail |
436 | 602 | } // namespace util |
437 | 603 | } // namespace atlas |
0 commit comments