|
#include "HNSWIndex.h"

#include <algorithm>
#include <fstream>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <faiss/index_io.h>
| 6 | + |
| 7 | +namespace zenann { |
| 8 | + |
// Build an empty HNSW index over `dim`-dimensional float vectors.
// M is the per-node neighbour count of the HNSW graph; efConstruction is the
// candidate-list width used while inserting vectors (larger = better graph
// quality, slower build).  The setting is forwarded to the owned faiss index.
HNSWIndex::HNSWIndex(size_t dim, size_t M, size_t efConstruction)
    : IndexBase(dim), idx_(new faiss::IndexHNSWFlat(dim, M)) {
    idx_->hnsw.efConstruction = efConstruction;
}
| 13 | + |
| 14 | +HNSWIndex::~HNSWIndex() = default; |
| 15 | + |
| 16 | +void HNSWIndex::train() { |
| 17 | + const auto& data = datastore_->getAll(); |
| 18 | + if (data.empty()) return; |
| 19 | + size_t n = data.size(); |
| 20 | + std::vector<float> flat(n * dimension_); |
| 21 | + for (size_t i = 0; i < n; ++i) { |
| 22 | + std::copy(data[i].begin(), data[i].end(), flat.begin() + i * dimension_); |
| 23 | + } |
| 24 | + idx_->add(n, flat.data()); |
| 25 | +} |
| 26 | + |
| 27 | +SearchResult HNSWIndex::search(const Vector& query, size_t k) const { |
| 28 | + std::vector<faiss::idx_t> labels(k); |
| 29 | + std::vector<float> distances(k); |
| 30 | + idx_->search(1, query.data(), k, distances.data(), labels.data()); |
| 31 | + |
| 32 | + SearchResult result; |
| 33 | + result.indices.assign(labels.begin(), labels.end()); |
| 34 | + result.distances.assign(distances.begin(), distances.end()); |
| 35 | + return result; |
| 36 | +} |
| 37 | + |
// Single-query search with an explicit efSearch (query-time candidate-list
// width; larger = higher recall, slower search).
// NOTE(review): despite being `const`, this mutates the shared
// idx_->hnsw.efSearch setting — the new value persists for all later searches
// and concurrent searches on this index would race; confirm callers expect that.
SearchResult HNSWIndex::search(const Vector& query, size_t k, size_t efSearch) const {
    idx_->hnsw.efSearch = efSearch;
    return search(query, k);
}
| 42 | + |
| 43 | +std::vector<SearchResult> HNSWIndex::search_batch(const Dataset& queries, size_t k) const { |
| 44 | + return search_batch(queries, k, idx_->hnsw.efSearch); |
| 45 | +} |
| 46 | + |
| 47 | +std::vector<SearchResult> HNSWIndex::search_batch(const Dataset& queries, size_t k, size_t efSearch) const { |
| 48 | + idx_->hnsw.efSearch = efSearch; |
| 49 | + size_t nq = queries.size(); |
| 50 | + std::vector<float> flat(nq * dimension_); |
| 51 | + for (size_t i = 0; i < nq; ++i) { |
| 52 | + std::copy(queries[i].begin(), queries[i].end(), flat.begin() + i * dimension_); |
| 53 | + } |
| 54 | + std::vector<faiss::idx_t> labels(nq * k); |
| 55 | + std::vector<float> distances(nq * k); |
| 56 | + idx_->search(nq, flat.data(), k, distances.data(), labels.data()); |
| 57 | + |
| 58 | + std::vector<SearchResult> results(nq); |
| 59 | + for (size_t i = 0; i < nq; ++i) { |
| 60 | + auto lb = labels.begin() + i * k; |
| 61 | + auto dd = distances.begin() + i * k; |
| 62 | + results[i].indices.assign(lb, lb + k); |
| 63 | + results[i].distances.assign(dd, dd + k); |
| 64 | + } |
| 65 | + return results; |
| 66 | +} |
| 67 | + |
// Set the query-time candidate-list width (faiss HNSW efSearch).
// Larger values trade search latency for recall; applies to all later searches.
void HNSWIndex::set_ef_search(size_t efSearch) {
    idx_->hnsw.efSearch = efSearch;
}
| 71 | + |
// Reorder the index's internal vector layout.
// NOTE(review): bfs_reorder() is not part of stock faiss — this appears to
// depend on a patched/forked faiss build (presumably renumbering nodes in BFS
// order for cache locality); confirm which faiss this links against.
void HNSWIndex::reorder_layout() {
    idx_->bfs_reorder();
}
| 75 | + |
| 76 | +void HNSWIndex::reorder_layout(const std::string& mapping_file) { |
| 77 | + auto new_order = idx_->bfs_reorder(); |
| 78 | + std::ofstream ofs(mapping_file, std::ios::binary); |
| 79 | + size_t sz = new_order.size(); |
| 80 | + ofs.write(reinterpret_cast<const char*>(&sz), sizeof(sz)); |
| 81 | + ofs.write(reinterpret_cast<const char*>(new_order.data()), sz * sizeof(int)); |
| 82 | +} |
| 83 | + |
| 84 | + |
// Persist the underlying faiss index to `filename` using faiss's native
// serialization; readable back via HNSWIndex::read_index.
void HNSWIndex::write_index(const std::string& filename) const {
    faiss::write_index(idx_.get(), filename.c_str());
}
| 88 | + |
| 89 | +std::shared_ptr<HNSWIndex> HNSWIndex::read_index(const std::string& filename) { |
| 90 | + faiss::Index* base = faiss::read_index(filename.c_str()); |
| 91 | + auto raw = dynamic_cast<faiss::IndexHNSWFlat*>(base); |
| 92 | + auto inst = std::make_shared<HNSWIndex>(raw->d, raw->hnsw.nb_neighbors(0)); |
| 93 | + inst->idx_.reset(raw); |
| 94 | + return inst; |
| 95 | +} |
| 96 | + |
| 97 | +double HNSWIndex::compute_recall_with_mapping( |
| 98 | + const std::vector<std::vector<faiss::idx_t>>& groundtruth, |
| 99 | + const std::vector<faiss::idx_t>& predicted_flat, |
| 100 | + size_t nq, size_t k, |
| 101 | + const std::string& mapping_file) { |
| 102 | + std::ifstream ifs(mapping_file, std::ios::binary); |
| 103 | + size_t sz; |
| 104 | + ifs.read(reinterpret_cast<char*>(&sz), sizeof(sz)); |
| 105 | + std::vector<int> old_to_new(sz); |
| 106 | + ifs.read(reinterpret_cast<char*>(old_to_new.data()), sz * sizeof(int)); |
| 107 | + std::vector<std::vector<faiss::idx_t>> pred(nq, std::vector<faiss::idx_t>(k)); |
| 108 | + for (size_t i = 0; i < nq; ++i) |
| 109 | + for (size_t j = 0; j < k; ++j) |
| 110 | + pred[i][j] = predicted_flat[i*k + j]; |
| 111 | + std::vector<double> recalls; |
| 112 | + for (size_t i = 0; i < nq; ++i) { |
| 113 | + std::set<faiss::idx_t> ts; |
| 114 | + for (size_t j = 0; j < k; ++j) { |
| 115 | + auto oldid = groundtruth[i][j]; |
| 116 | + ts.insert(old_to_new[oldid]); |
| 117 | + } |
| 118 | + std::set<faiss::idx_t> ps(pred[i].begin(), pred[i].end()); |
| 119 | + size_t hits=0; |
| 120 | + for (auto pid: ps) if (ts.count(pid)) ++hits; |
| 121 | + recalls.push_back(double(hits)/ts.size()); |
| 122 | + } |
| 123 | + return std::accumulate(recalls.begin(), recalls.end(), 0.0)/recalls.size(); |
| 124 | +} |
| 125 | + |
| 126 | +} |
0 commit comments