-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathhnsw.h
More file actions
2362 lines (2058 loc) · 104 KB
/
hnsw.h
File metadata and controls
2362 lines (2058 loc) · 104 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2006-Present, Redis Ltd.
* All rights reserved.
*
* Licensed under your choice of the Redis Source Available License 2.0
* (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
* GNU Affero General Public License v3 (AGPLv3).
*/
#pragma once
#include "graph_data.h"
#include "visited_nodes_handler.h"
#include "VecSim/memory/vecsim_malloc.h"
#include "VecSim/utils/vecsim_stl.h"
#include "VecSim/utils/vec_utils.h"
#include "VecSim/containers/data_block.h"
#include "VecSim/containers/raw_data_container_interface.h"
#include "VecSim/containers/data_blocks_container.h"
#include "VecSim/containers/vecsim_results_container.h"
#include "VecSim/query_result_definitions.h"
#include "VecSim/vec_sim_common.h"
#include "VecSim/vec_sim_index.h"
#include "VecSim/tombstone_interface.h"
#ifdef BUILD_TESTS
#include "hnsw_serialization_utils.h"
#include "VecSim/utils/serializer.h"
#include "hnsw_serializer.h"
#endif
#include <deque>
#include <memory>
#include <cassert>
#include <climits>
#include <queue>
#include <random>
#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <sys/resource.h>
#include <fstream>
#include <shared_mutex>
using std::pair;

// Storage type of the per-element bit flags (the individual bits are listed in `Flags`).
using elementFlags = uint8_t;

// Max priority queue of (distance, internal id) pairs.
template <typename DistType>
using candidatesMaxHeap = vecsim_stl::max_priority_queue<DistType, idType>;

// Flat list of (distance, internal id) pairs.
template <typename DistType>
using candidatesList = vecsim_stl::vector<pair<DistType, idType>>;

// Abstract priority queue of (distance, external label) pairs.
template <typename DistType>
using candidatesLabelsMaxHeap = vecsim_stl::abstract_priority_queue<DistType, labelType>;

// A node of the graph, represented as: (element_id, level)
using graphNodeType = pair<idType, unsigned short>;
////////////////////////////////////// Auxiliary HNSW structs //////////////////////////////////////
// Bit flags that can be set on a single stored vector (kept in `ElementMetaData::flags`).
enum Flags {
    // The element is logically deleted, but still exists in the graph.
    DELETE_MARK = 0x1,
    // The element is being inserted into the graph.
    IN_PROCESS = 0x2,
};
// The state of the index and the newly stored vector to be passed to indexVector.
// It is the return type of storeVector()/storeNewElement() and is taken by const reference
// in indexVector().
struct HNSWAddVectorState {
// Internal id of the new element (presumably the id assigned when it was stored - confirm
// against storeNewElement()).
idType newElementId;
// Top level of the new element (presumably the randomly drawn level - confirm).
int elementMaxLevel;
// Entry point of the index (presumably as observed when the vector was stored - confirm).
idType currEntryPoint;
// Max level of the index (presumably as observed when the vector was stored - confirm).
int currMaxLevel;
};
// Per-element meta-data: the external label and the element's flag bits.
// The struct is packed to 1-byte alignment, so no padding is inserted between the fields.
#pragma pack(1)
struct ElementMetaData {
// External label of the element.
labelType label;
// Bitwise combination of `Flags` values.
elementFlags flags;
// A new element starts with the IN_PROCESS flag set. The label defaults to SIZE_MAX when
// none is given.
explicit ElementMetaData(labelType label = SIZE_MAX) noexcept
: label(label), flags(IN_PROCESS) {}
};
#pragma pack() // restore default packing
//////////////////////////////////// HNSW index implementation ////////////////////////////////////
// HNSW graph index over vectors of `DataType`, with distances expressed as `DistType`.
// The class is abstract: the label <-> internal id bookkeeping (getNewMaxPriorityQueue,
// getElementIds, removeLabel, getNewResultsContainer, replaceIdOfLabel, setVectorId and
// resizeLabelLookup) is left to derived classes.
template <typename DataType, typename DistType>
class HNSWIndex : public VecSimIndexAbstract<DataType, DistType>,
                  public VecSimIndexTombstone
#ifdef BUILD_TESTS
    ,
                  public HNSWSerializer
#endif
{
protected:
    // Index build parameters
    size_t maxElements;
    size_t M;
    size_t M0;
    size_t efConstruction;

    // Index search parameter
    size_t ef;
    double epsilon;

    // Index meta-data (based on the data dimensionality and index parameters)
    size_t elementGraphDataSize;
    size_t levelDataSize;
    double mult;

    // Index level generator of the top level for a new element
    std::default_random_engine levelGenerator;

    // Index global state - these should be guarded by the indexDataGuard lock in
    // multithreaded scenario.
    size_t curElementCount;
    idType entrypointNode;
    size_t maxLevel; // this is the top level of the entry point's element

    // Index data
    vecsim_stl::vector<DataBlock> graphDataBlocks;
    vecsim_stl::vector<ElementMetaData> idToMetaData;

    // Used for marking the visited nodes in graph scans (the pool supports parallel graph scans).
    // This is mutable since the object changes upon search operations as well (which are const).
    mutable VisitedNodesHandlerPool visitedNodesHandlerPool;

    mutable std::shared_mutex indexDataGuard;

#ifdef BUILD_TESTS
#include "VecSim/algorithms/hnsw/hnsw_base_tests_friends.h"

#include "hnsw_serializer_declarations.h"
#endif

protected:
    HNSWIndex() = delete;                  // default constructor is disabled.
    HNSWIndex(const HNSWIndex &) = delete; // default (shallow) copy constructor is disabled.

    // Draws a random level using `levelGenerator`, scaled by `reverse_size`.
    size_t getRandomLevel(double reverse_size);

    // Scans the neighbors of `curNodeId` at `layer` and updates the candidates heaps and the
    // lower bound accordingly.
    template <typename Identifier> // Either idType or labelType
    void processCandidate(idType curNodeId, const void *data_point, size_t layer, size_t ef,
                          tag_t *elements_tags, tag_t visited_tag,
                          vecsim_stl::abstract_priority_queue<DistType, Identifier> &top_candidates,
                          candidatesMaxHeap<DistType> &candidates_set, DistType &lowerBound) const;

    // Range-search flavor of processCandidate: neighbors within `radius` are added to
    // `top_candidates` (the results container).
    void processCandidate_RangeSearch(
        idType curNodeId, const void *data_point, size_t layer, double epsilon,
        tag_t *elements_tags, tag_t visited_tag,
        std::unique_ptr<vecsim_stl::abstract_results_container> &top_candidates,
        candidatesMaxHeap<DistType> &candidate_set, DistType lowerBound, DistType radius) const;

    candidatesMaxHeap<DistType> searchLayer(idType ep_id, const void *data_point, size_t layer,
                                            size_t ef) const;
    candidatesLabelsMaxHeap<DistType> *
    searchBottomLayer_WithTimeout(idType ep_id, const void *data_point, size_t ef, size_t k,
                                  void *timeoutCtx, VecSimQueryReply_Code *rc) const;
    VecSimQueryResultContainer searchRangeBottomLayer_WithTimeout(idType ep_id,
                                                                  const void *data_point,
                                                                  double epsilon, DistType radius,
                                                                  void *timeoutCtx,
                                                                  VecSimQueryReply_Code *rc) const;

    // Neighbors selection heuristic. The first overload returns the id of the closest candidate;
    // the second one also reports the candidates that were not chosen.
    idType getNeighborsByHeuristic2(candidatesList<DistType> &top_candidates, size_t M) const;
    void getNeighborsByHeuristic2(candidatesList<DistType> &top_candidates, size_t M,
                                  vecsim_stl::vector<idType> &not_chosen_candidates) const;
    template <bool record_removed>
    void getNeighborsByHeuristic2_internal(
        candidatesList<DistType> &top_candidates, size_t M,
        vecsim_stl::vector<idType> *removed_candidates = nullptr) const;

    // Helper function for re-selecting node's neighbors which was selected as a neighbor for
    // a newly inserted node. Also, responsible for mutually connect the new node and the neighbor
    // (unidirectional or bidirectional connection).
    // *Note that node_lock and neighbor_lock should be locked upon calling this function*
    void revisitNeighborConnections(size_t level, idType new_node_id,
                                    const std::pair<DistType, idType> &neighbor_data,
                                    ElementLevelData &new_node_level,
                                    ElementLevelData &neighbor_level);
    idType mutuallyConnectNewElement(idType new_node_id,
                                     candidatesMaxHeap<DistType> &top_candidates, size_t level);
    void mutuallyUpdateForRepairedNode(idType node_id, size_t level,
                                       vecsim_stl::vector<idType> &neighbors_to_remove,
                                       vecsim_stl::vector<idType> &nodes_to_update,
                                       vecsim_stl::vector<idType> &chosen_neighbors,
                                       size_t max_M_cur);

    template <bool running_query>
    void greedySearchLevel(const void *vector_data, size_t level, idType &curObj, DistType &curDist,
                           void *timeoutCtx = nullptr, VecSimQueryReply_Code *rc = nullptr) const;
    void repairConnectionsForDeletion(idType element_internal_id, idType neighbour_id,
                                      ElementLevelData &node_level,
                                      ElementLevelData &neighbor_level, size_t level,
                                      vecsim_stl::vector<bool> &neighbours_bitmap);
    void replaceEntryPoint();

    void SwapLastIdWithDeletedId(idType element_internal_id, ElementGraphData *last_element,
                                 const void *last_element_data);

    /** Add vector functions */
    // Protected internal function that implements generic single vector insertion.
    void appendVector(const void *vector_data, labelType label);

    HNSWAddVectorState storeVector(const void *vector_data, const labelType label);

    // Protected internal functions for index resizing.
    void growByBlock();
    void shrinkByBlock();
    // DO NOT USE DIRECTLY. Use `[grow|shrink]ByBlock` instead.
    void resizeIndexCommon(size_t new_max_elements);

    // Emplace (dist, id) into a heap of internal ids, or (dist, label of id) into a heap of labels.
    void emplaceToHeap(vecsim_stl::abstract_priority_queue<DistType, idType> &heap, DistType dist,
                       idType id) const;
    void emplaceToHeap(vecsim_stl::abstract_priority_queue<DistType, labelType> &heap,
                       DistType dist, idType id) const;
    void removeAndSwap(idType internalId);

    // Position of the element inside its block.
    size_t getVectorRelativeIndex(idType id) const { return id % this->blockSize; }

    // Flagging API
    // Atomically sets FLAG in the element's flags byte (relaxed memory order).
    template <Flags FLAG>
    void markAs(idType internalId) {
        __atomic_fetch_or(&idToMetaData[internalId].flags, FLAG, __ATOMIC_RELAXED);
    }
    // Atomically clears FLAG in the element's flags byte (relaxed memory order).
    template <Flags FLAG>
    void unmarkAs(idType internalId) {
        __atomic_fetch_and(&idToMetaData[internalId].flags, ~FLAG, __ATOMIC_RELAXED);
    }
    // Tells whether FLAG is currently set for the element (plain, non-atomic read).
    template <Flags FLAG>
    bool isMarkedAs(idType internalId) const {
        return idToMetaData[internalId].flags & FLAG;
    }
    void mutuallyRemoveNeighborAtPos(ElementLevelData &node_level, size_t level, idType node_id,
                                     size_t pos);

public:
    HNSWIndex(const HNSWParams *params, const AbstractIndexInitParams &abstractInitParams,
              const IndexComponents<DataType, DistType> &components, size_t random_seed = 100);
    virtual ~HNSWIndex();

    void setEf(size_t ef);
    size_t getEf() const;
    void setEpsilon(double epsilon);
    double getEpsilon() const;
    size_t indexSize() const override;
    size_t indexCapacity() const override;
    /**
     * Checks if the index capacity is full to hint the caller a resize is needed.
     * @note Must be called with indexDataGuard locked.
     */
    size_t isCapacityFull() const;
    size_t getEfConstruction() const;
    size_t getM() const;
    size_t getMaxLevel() const;
    labelType getEntryPointLabel() const;
    labelType getExternalLabel(idType internal_id) const { return idToMetaData[internal_id].label; }
    auto safeGetEntryPointState() const;

    // Locking API: the global index data guard (exclusive / shared) and the per-node links guard.
    void lockIndexDataGuard() const;
    void unlockIndexDataGuard() const;
    void lockSharedIndexDataGuard() const;
    void unlockSharedIndexDataGuard() const;
    void lockNodeLinks(idType node_id) const;
    void unlockNodeLinks(idType node_id) const;
    void lockNodeLinks(ElementGraphData *node_data) const;
    void unlockNodeLinks(ElementGraphData *node_data) const;

    // Borrow a visited-nodes handler from the pool / give it back.
    VisitedNodesHandler *getVisitedList() const;
    void returnVisitedList(VisitedNodesHandler *visited_nodes_handler) const;

    VecSimIndexDebugInfo debugInfo() const override;
    VecSimIndexBasicInfo basicInfo() const override;
    VecSimDebugInfoIterator *debugInfoIterator() const override;
    bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override;

    const char *getDataByInternalId(idType internal_id) const;
    ElementGraphData *getGraphDataByInternalId(idType internal_id) const;
    ElementLevelData &getElementLevelData(idType internal_id, size_t level) const;
    ElementLevelData &getElementLevelData(ElementGraphData *element, size_t level) const;
    idType searchBottomLayerEP(const void *query_data, void *timeoutCtx,
                               VecSimQueryReply_Code *rc) const;

    void indexVector(const void *vector_data, const labelType label,
                     const HNSWAddVectorState &state);
    VecSimQueryReply *topKQuery(const void *query_data, size_t k,
                                VecSimQueryParams *queryParams) const override;
    VecSimQueryReply *rangeQuery(const void *query_data, double radius,
                                 VecSimQueryParams *queryParams) const override;

    void markDeletedInternal(idType internalId);
    bool isMarkedDeleted(idType internalId) const;
    bool isInProcess(idType internalId) const;
    void unmarkInProcess(idType internalId);
    HNSWAddVectorState storeNewElement(labelType label, const void *vector_data);
    void removeAndSwapMarkDeletedElement(idType internalId);
    void repairNodeConnections(idType node_id, size_t level);
    // For prefetching only.
    const ElementMetaData *getMetaDataAddress(idType internal_id) const {
        return idToMetaData.data() + internal_id;
    }
    vecsim_stl::vector<graphNodeType> safeCollectAllNodeIncomingNeighbors(idType node_id) const;
    VecSimDebugCommandCode getHNSWElementNeighbors(size_t label, int ***neighborsData);
    void insertElementToGraph(idType element_id, size_t element_max_level, idType entry_point,
                              size_t global_max_level, const void *vector_data);
    void removeVectorInPlace(idType id);

    /*************************** Labels lookup API ***************************/
    // Inline priority queue getter that need to be implemented by derived class.
    virtual inline candidatesLabelsMaxHeap<DistType> *getNewMaxPriorityQueue() const = 0;

    // Unsafe (assume index data guard is held in MT mode).
    virtual vecsim_stl::vector<idType> getElementIds(size_t label) = 0;

    // Remove label from the index.
    virtual int removeLabel(labelType label) = 0;

#ifdef BUILD_TESTS
    // Shrinks the meta-data vector to its size and resizes the label lookup to match.
    void fitMemory() override {
        if (maxElements > 0) {
            idToMetaData.shrink_to_fit();
            resizeLabelLookup(idToMetaData.size());
        }
    }

    size_t indexMetaDataCapacity() const override { return idToMetaData.capacity(); }
#endif

protected:
    // inline label to id setters that need to be implemented by derived class
    virtual std::unique_ptr<vecsim_stl::abstract_results_container>
    getNewResultsContainer(size_t cap) const = 0;
    virtual void replaceIdOfLabel(labelType label, idType new_id, idType old_id) = 0;
    virtual void setVectorId(labelType label, idType id) = 0;
    virtual void resizeLabelLookup(size_t new_max_elements) = 0;
};
/**
* getters and setters of index data
*/
// Sets the `ef` search parameter of the index.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::setEf(size_t ef) {
// The parameter shadows the member, hence the explicit `this->`.
this->ef = ef;
}
// Returns the current `ef` search parameter of the index.
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::getEf() const {
    return ef;
}
// Sets the `epsilon` search parameter of the index.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::setEpsilon(double epsilon) {
// The parameter shadows the member, hence the explicit `this->`.
this->epsilon = epsilon;
}
// Returns the current `epsilon` search parameter of the index.
template <typename DataType, typename DistType>
double HNSWIndex<DataType, DistType>::getEpsilon() const {
    return epsilon;
}
// Returns the current elements count of the index (`curElementCount`).
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::indexSize() const {
    return curElementCount;
}
// Returns the current capacity of the index (`maxElements`).
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::indexCapacity() const {
    return maxElements;
}
// Returns non-zero when the index size has reached `maxElements`.
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::isCapacityFull() const {
    const bool is_full = (this->indexSize() == maxElements);
    return is_full;
}
// Returns the `efConstruction` build parameter of the index.
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::getEfConstruction() const {
    return efConstruction;
}
// Returns the `M` build parameter of the index.
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::getM() const {
    return M;
}
// Returns the top level of the entry point's element (`maxLevel`).
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::getMaxLevel() const {
    return maxLevel;
}
// Returns the external label of the entry point, or SIZE_MAX when there is no entry point
// (i.e., `entrypointNode` is INVALID_ID).
template <typename DataType, typename DistType>
labelType HNSWIndex<DataType, DistType>::getEntryPointLabel() const {
    if (entrypointNode == INVALID_ID) {
        return SIZE_MAX;
    }
    return getExternalLabel(entrypointNode);
}
// Returns a pointer to the vector data that is stored for the given internal id.
template <typename DataType, typename DistType>
const char *HNSWIndex<DataType, DistType>::getDataByInternalId(idType internal_id) const {
    const char *vector_data = this->vectors->getElement(internal_id);
    return vector_data;
}
// Returns the graph data of the given element: the block is selected by
// `internal_id / blockSize`, and the position inside the block by `internal_id % blockSize`.
template <typename DataType, typename DistType>
ElementGraphData *
HNSWIndex<DataType, DistType>::getGraphDataByInternalId(idType internal_id) const {
    const auto block_index = internal_id / this->blockSize;
    const auto index_in_block = internal_id % this->blockSize;
    return (ElementGraphData *)graphDataBlocks[block_index].getElement(index_in_block);
}
// Draws a random level: -log(u) * reverse_size truncated to an integer, where u is sampled
// uniformly from [0, 1) using the index level generator.
// NOTE(review): u == 0 would make -log(u) infinite before the cast - confirm that this cannot
// happen in practice with the generator in use.
template <typename DataType, typename DistType>
size_t HNSWIndex<DataType, DistType>::getRandomLevel(double reverse_size) {
    std::uniform_real_distribution<double> unit_interval(0.0, 1.0);
    const double sample = unit_interval(levelGenerator);
    const double scaled_level = -log(sample) * reverse_size;
    return (size_t)scaled_level;
}
// Returns the level data of the element with the given internal id at the requested level.
template <typename DataType, typename DistType>
ElementLevelData &HNSWIndex<DataType, DistType>::getElementLevelData(idType internal_id,
                                                                     size_t level) const {
    ElementGraphData *graph_data = getGraphDataByInternalId(internal_id);
    return graph_data->getElementLevelData(level, levelDataSize);
}
// Returns the level data of the given element's graph data at the requested level.
template <typename DataType, typename DistType>
ElementLevelData &HNSWIndex<DataType, DistType>::getElementLevelData(ElementGraphData *graph_data,
                                                                     size_t level) const {
    ElementLevelData &level_data = graph_data->getElementLevelData(level, levelDataSize);
    return level_data;
}
// Obtains an available visited-nodes handler from the pool.
template <typename DataType, typename DistType>
VisitedNodesHandler *HNSWIndex<DataType, DistType>::getVisitedList() const {
    VisitedNodesHandler *handler = visitedNodesHandlerPool.getAvailableVisitedNodesHandler();
    return handler;
}
// Gives a visited-nodes handler back to the pool.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::returnVisitedList(
    VisitedNodesHandler *visited_nodes_handler) const {
    this->visitedNodesHandlerPool.returnVisitedNodesHandlerToPool(visited_nodes_handler);
}
// Marks the given element as deleted (a no-op if it is already marked). If the element is the
// current entry point, the entry point is replaced first.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::markDeletedInternal(idType internalId) {
    // Here we are holding the global index data guard (and the main index lock of the tiered index
    // for shared ownership).
    assert(internalId < this->curElementCount);
    if (isMarkedDeleted(internalId)) {
        return;
    }
    if (entrypointNode == internalId) {
        // Internally, we hold and release the entrypoint neighbors lock.
        replaceEntryPoint();
    }
    // The deletion mark is set atomically, since other threads may change the flags of this
    // element at the same time (for changing the IN_PROCESS flag).
    markAs<DELETE_MARK>(internalId);
    this->numMarkedDeleted++;
}
// Tells whether the DELETE_MARK flag is set for the given element.
template <typename DataType, typename DistType>
bool HNSWIndex<DataType, DistType>::isMarkedDeleted(idType internalId) const {
    const bool deleted = isMarkedAs<DELETE_MARK>(internalId);
    return deleted;
}
// Tells whether the IN_PROCESS flag is set for the given element.
template <typename DataType, typename DistType>
bool HNSWIndex<DataType, DistType>::isInProcess(idType internalId) const {
    const bool in_process = isMarkedAs<IN_PROCESS>(internalId);
    return in_process;
}
// Clears the IN_PROCESS flag of the given element.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unmarkInProcess(idType internalId) {
    // The flag is cleared atomically, since other parallel threads may update the flags of this
    // element at the same time.
    this->template unmarkAs<IN_PROCESS>(internalId);
}
// Acquires the global index data guard exclusively.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockIndexDataGuard() const {
    this->indexDataGuard.lock();
}
// Releases the exclusive hold of the global index data guard.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockIndexDataGuard() const {
    this->indexDataGuard.unlock();
}
// Acquires the global index data guard for shared ownership.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockSharedIndexDataGuard() const {
    this->indexDataGuard.lock_shared();
}
// Releases a shared hold of the global index data guard.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockSharedIndexDataGuard() const {
    this->indexDataGuard.unlock_shared();
}
// Locks the neighbors guard of the given node.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockNodeLinks(ElementGraphData *node_data) const {
    auto &links_guard = node_data->neighborsGuard;
    links_guard.lock();
}
// Unlocks the neighbors guard of the given node.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockNodeLinks(ElementGraphData *node_data) const {
    auto &links_guard = node_data->neighborsGuard;
    links_guard.unlock();
}
// Locks the neighbors guard of the node with the given internal id.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockNodeLinks(idType node_id) const {
    ElementGraphData *node_data = getGraphDataByInternalId(node_id);
    lockNodeLinks(node_data);
}
// Unlocks the neighbors guard of the node with the given internal id.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockNodeLinks(idType node_id) const {
    ElementGraphData *node_data = getGraphDataByInternalId(node_id);
    unlockNodeLinks(node_data);
}
/**
* helper functions
*/
// Overload for heaps of internal ids: emplaces the (dist, id) pair as is.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::emplaceToHeap(
vecsim_stl::abstract_priority_queue<DistType, idType> &heap, DistType dist, idType id) const {
heap.emplace(dist, id);
}
// Overload for heaps of labels: the internal id is translated to its external label, and the
// (dist, label) pair is emplaced.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::emplaceToHeap(
    vecsim_stl::abstract_priority_queue<DistType, labelType> &heap, DistType dist,
    idType id) const {
    const labelType label = getExternalLabel(id);
    heap.emplace(dist, label);
}
// Scans the neighbors of `curNodeId` at `layer` (while holding the node's links lock) and
// updates `candidate_set`, `top_candidates` and `lowerBound` for every neighbor that was not
// visited yet and is not in process.
// This function handles both label heaps and internal ids heaps. It uses the `emplaceToHeap`
// overloading to emplace correctly for both cases.
template <typename DataType, typename DistType>
template <typename Identifier>
void HNSWIndex<DataType, DistType>::processCandidate(
    idType curNodeId, const void *query_data, size_t layer, size_t ef, tag_t *elements_tags,
    tag_t visited_tag, vecsim_stl::abstract_priority_queue<DistType, Identifier> &top_candidates,
    candidatesMaxHeap<DistType> &candidate_set, DistType &lowerBound) const {

    // Evaluates a single neighbor, given its (already pre-fetched) vector data.
    auto visit_neighbor = [&](idType neighbor_id, const char *neighbor_data) {
        if (elements_tags[neighbor_id] == visited_tag || isInProcess(neighbor_id)) {
            return;
        }
        elements_tags[neighbor_id] = visited_tag;

        const DistType neighbor_dist = this->calcDistance(query_data, neighbor_data);
        if (lowerBound > neighbor_dist || top_candidates.size() < ef) {
            candidate_set.emplace(-neighbor_dist, neighbor_id);

            // Only elements that are not marked as deleted may enter the top candidates heap.
            if (!isMarkedDeleted(neighbor_id)) {
                emplaceToHeap(top_candidates, neighbor_dist, neighbor_id);
            }
            if (top_candidates.size() > ef) {
                top_candidates.pop();
            }
            // `top_candidates` may still be empty at this point, if only marked deleted
            // elements were encountered so far.
            if (!top_candidates.empty()) {
                lowerBound = top_candidates.top().first;
            }
        }
    };

    ElementGraphData *cur_element = getGraphDataByInternalId(curNodeId);
    lockNodeLinks(cur_element);
    ElementLevelData &node_level = getElementLevelData(cur_element, layer);
    const linkListSize num_links = node_level.getNumLinks();

    if (num_links > 0) {
        // Pre-fetch the tag address and the data address of the first neighbor.
        __builtin_prefetch(elements_tags + node_level.getLinkAtPos(0));
        const char *next_data = getDataByInternalId(node_level.getLinkAtPos(0));
        __builtin_prefetch(next_data);

        for (linkListSize j = 0; j < num_links - 1; j++) {
            const idType neighbor_id = node_level.getLinkAtPos(j);
            const char *neighbor_data = next_data;

            // Pre-fetch the tag address and the data address of the following neighbor, before
            // evaluating the current one.
            __builtin_prefetch(elements_tags + node_level.getLinkAtPos(j + 1));
            next_data = getDataByInternalId(node_level.getLinkAtPos(j + 1));
            __builtin_prefetch(next_data);

            visit_neighbor(neighbor_id, neighbor_data);
        }

        // The last neighbor is handled outside the loop to avoid prefetching invalid neighbor.
        visit_neighbor(node_level.getLinkAtPos(num_links - 1), next_data);
    }
    unlockNodeLinks(cur_element);
}
// Range-search flavor of the neighbors scan: goes over the neighbors of `curNodeId` at `layer`
// (while holding the node's links lock). Every neighbor that was not visited yet, is not in
// process and is closer than `dyn_range` becomes a candidate; if it also lies within `radius`
// and is not marked deleted, it is added to `results`.
// Note that `epsilon` is not referenced in this function's body.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::processCandidate_RangeSearch(
    idType curNodeId, const void *query_data, size_t layer, double epsilon, tag_t *elements_tags,
    tag_t visited_tag, std::unique_ptr<vecsim_stl::abstract_results_container> &results,
    candidatesMaxHeap<DistType> &candidate_set, DistType dyn_range, DistType radius) const {

    // Evaluates a single neighbor, given its (already pre-fetched) vector data.
    auto visit_neighbor = [&](idType neighbor_id, const char *neighbor_data) {
        if (elements_tags[neighbor_id] == visited_tag || isInProcess(neighbor_id)) {
            return;
        }
        elements_tags[neighbor_id] = visited_tag;

        const DistType neighbor_dist = this->calcDistance(query_data, neighbor_data);
        if (neighbor_dist < dyn_range) {
            candidate_set.emplace(-neighbor_dist, neighbor_id);

            // A candidate within the requested radius joins the results set.
            if (neighbor_dist <= radius && !isMarkedDeleted(neighbor_id)) {
                results->emplace(getExternalLabel(neighbor_id), neighbor_dist);
            }
        }
    };

    auto *cur_element = getGraphDataByInternalId(curNodeId);
    lockNodeLinks(cur_element);
    ElementLevelData &node_level = getElementLevelData(cur_element, layer);
    const linkListSize num_links = node_level.getNumLinks();

    if (num_links > 0) {
        // Pre-fetch the tag address and the data address of the first candidate.
        __builtin_prefetch(elements_tags + node_level.getLinkAtPos(0));
        const char *next_data = getDataByInternalId(node_level.getLinkAtPos(0));
        __builtin_prefetch(next_data);

        for (linkListSize j = 0; j < num_links - 1; j++) {
            const idType neighbor_id = node_level.getLinkAtPos(j);
            const char *neighbor_data = next_data;

            // Pre-fetch the tag address and the data address of the following candidate,
            // before evaluating the current one.
            __builtin_prefetch(elements_tags + node_level.getLinkAtPos(j + 1));
            next_data = getDataByInternalId(node_level.getLinkAtPos(j + 1));
            __builtin_prefetch(next_data);

            visit_neighbor(neighbor_id, neighbor_data);
        }

        // The last candidate is handled outside the loop to avoid prefetching invalid candidate.
        visit_neighbor(node_level.getLinkAtPos(num_links - 1), next_data);
    }
    unlockNodeLinks(cur_element);
}
// Searches `layer` starting from the entry point `ep_id`, expanding the closest pending
// candidate at a time, and returns the collected top candidates heap.
// The search stops when the closest pending candidate is farther than the current lower bound
// while `ef` top candidates were already collected, or when no candidates are left.
template <typename DataType, typename DistType>
candidatesMaxHeap<DistType>
HNSWIndex<DataType, DistType>::searchLayer(idType ep_id, const void *data_point, size_t layer,
                                           size_t ef) const {
    VisitedNodesHandler *visited_nodes = getVisitedList();
    const tag_t visited_tag = visited_nodes->getFreshTag();

    candidatesMaxHeap<DistType> top_candidates(this->allocator);
    candidatesMaxHeap<DistType> candidate_set(this->allocator);

    // `candidate_set` holds negated distances, so that its top is the closest candidate.
    DistType lowerBound = std::numeric_limits<DistType>::max();
    if (isMarkedDeleted(ep_id)) {
        // A deleted entry point is only expanded; it does not enter the top candidates.
        candidate_set.emplace(-lowerBound, ep_id);
    } else {
        lowerBound = this->calcDistance(data_point, getDataByInternalId(ep_id));
        top_candidates.emplace(lowerBound, ep_id);
        candidate_set.emplace(-lowerBound, ep_id);
    }
    visited_nodes->tagNode(ep_id, visited_tag);

    while (!candidate_set.empty()) {
        const pair<DistType, idType> closest = candidate_set.top();
        const DistType closest_dist = -closest.first;
        if (closest_dist > lowerBound && top_candidates.size() >= ef) {
            break;
        }
        candidate_set.pop();

        processCandidate(closest.second, data_point, layer, ef, visited_nodes->getElementsTags(),
                         visited_tag, top_candidates, candidate_set, lowerBound);
    }

    returnVisitedList(visited_nodes);
    return top_candidates;
}
// Reduces `top_candidates` using the neighbors selection heuristic, and returns the id of the
// closest candidate. If there are fewer than M candidates, the list is left untouched and the
// id of the candidate with the minimal distance is returned.
template <typename DataType, typename DistType>
idType
HNSWIndex<DataType, DistType>::getNeighborsByHeuristic2(candidatesList<DistType> &top_candidates,
                                                        const size_t M) const {
    if (top_candidates.size() >= M) {
        getNeighborsByHeuristic2_internal<false>(top_candidates, M, nullptr);
        // The head of the list is returned as the closest selected candidate.
        return top_candidates.front().second;
    }

    const auto closest =
        std::min_element(top_candidates.begin(), top_candidates.end(),
                         [](const auto &lhs, const auto &rhs) { return lhs.first < rhs.first; });
    return closest->second;
}
// Reduces `top_candidates` using the neighbors selection heuristic, and collects the ids of
// the candidates that were dropped into `removed_candidates`.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::getNeighborsByHeuristic2(
    candidatesList<DistType> &top_candidates, const size_t M,
    vecsim_stl::vector<idType> &removed_candidates) const {
    vecsim_stl::vector<idType> *removed_out = &removed_candidates;
    getNeighborsByHeuristic2_internal<true>(top_candidates, M, removed_out);
}
// Implementation of the neighbors selection heuristic. Nothing is done if there are fewer than
// M candidates. Otherwise, the candidates are sorted by their distance, and at most M of them
// are selected; upon return, `top_candidates` holds the selected ones.
// When `record_removed` is set, the ids of all the candidates that were not selected are
// appended to `removed_candidates` (which must not be null in that case).
template <typename DataType, typename DistType>
template <bool record_removed>
void HNSWIndex<DataType, DistType>::getNeighborsByHeuristic2_internal(
    candidatesList<DistType> &top_candidates, const size_t M,
    vecsim_stl::vector<idType> *removed_candidates) const {
    if (top_candidates.size() < M) {
        return;
    }

    candidatesList<DistType> selected(this->allocator);
    // The vectors of the selected candidates, in the same order as `selected`.
    vecsim_stl::vector<const void *> selected_vectors(this->allocator);
    selected.reserve(M);
    selected_vectors.reserve(M);
    if constexpr (record_removed) {
        removed_candidates->reserve(top_candidates.size());
    }

    // Sort the candidates by their distance (we don't mind the secondary order (the internal id))
    std::sort(top_candidates.begin(), top_candidates.end(),
              [](const auto &lhs, const auto &rhs) { return lhs.first < rhs.first; });

    auto candidate = top_candidates.begin();
    const auto candidates_end = top_candidates.end();
    while (candidate != candidates_end && selected.size() < M) {
        const DistType dist_to_query = candidate->first;
        const void *candidate_vector = getDataByInternalId(candidate->second);

        // A candidate is rejected if some already selected neighbour is closer to it than the
        // candidate is to the query.
        const bool rejected = std::any_of(
            selected_vectors.begin(), selected_vectors.end(), [&](const void *selected_vector) {
                const DistType dist_to_selected =
                    this->calcDistance(selected_vector, candidate_vector);
                return dist_to_selected < dist_to_query;
            });

        if (rejected) {
            if constexpr (record_removed) {
                removed_candidates->push_back(candidate->second);
            }
        } else {
            selected_vectors.push_back(candidate_vector);
            selected.push_back(*candidate);
        }
        ++candidate;
    }

    // Once M candidates were selected, all the remaining candidates are dropped as well.
    if constexpr (record_removed) {
        for (; candidate != candidates_end; ++candidate) {
            removed_candidates->push_back(candidate->second);
        }
    }

    top_candidates.swap(selected);
}
// Re-selects the outgoing links of an existing neighbor (`neighbor_data.second`) at `level`,
// considering the new node as one more candidate, and applies the result to the graph.
// mutuallyConnectNewElement calls this when the neighbor's link list has no free slot.
//
// level          - graph level being updated; picks the link cap (M0 at level 0, M otherwise).
// new_node_id    - id of the node that is being inserted.
// neighbor_data  - (pre-computed distance between the new node and the neighbor, neighbor id).
// new_node_level - link data of the new node at `level`.
// neighbor_level - link data of the neighbor at `level`.
//
// Locking: the caller is expected to hold the link locks of both the new node and the neighbor.
// Both are released here; afterwards every node that is about to be touched is locked in
// ascending id order, and all of these locks are released again before returning.
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::revisitNeighborConnections(
size_t level, idType new_node_id, const std::pair<DistType, idType> &neighbor_data,
ElementLevelData &new_node_level, ElementLevelData &neighbor_level) {
// Note - expect that node_lock and neighbor_lock are locked at that point.
// Collect the existing neighbors and the new node as the neighbor's neighbors candidates.
candidatesList<DistType> candidates(this->allocator);
candidates.reserve(neighbor_level.getNumLinks() + 1);
// Add the new node along with the pre-calculated distance to the current neighbor,
candidates.emplace_back(neighbor_data.first, new_node_id);
idType selected_neighbor = neighbor_data.second;
const void *selected_neighbor_data = getDataByInternalId(selected_neighbor);
// ...followed by every current link of the neighbor, with its distance to the neighbor.
for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) {
candidates.emplace_back(
this->calcDistance(getDataByInternalId(neighbor_level.getLinkAtPos(j)),
selected_neighbor_data),
neighbor_level.getLinkAtPos(j));
}
// After the heuristic runs, `candidates` holds the newly selected neighbours (for the
// neighbor), and `nodes_to_update` holds the ids of the candidates that were NOT selected.
size_t max_M_cur = level ? M : M0;
vecsim_stl::vector<idType> nodes_to_update(this->allocator);
getNeighborsByHeuristic2(candidates, max_M_cur, nodes_to_update);
// Acquire all relevant locks for making the updates for the selected neighbor - all its removed
// neighbors, along with the neighbor itself and the cur node.
// But first, we release the node and neighbor locks to avoid deadlocks.
unlockNodeLinks(new_node_id);
unlockNodeLinks(selected_neighbor);
// Check if the new node was selected as a neighbor for the current neighbor: if it shows up in
// the removed list it was not chosen. Otherwise it was chosen, and it is added to the list of
// nodes to update so that it gets locked together with the others.
bool cur_node_chosen;
auto new_node_iter = std::find(nodes_to_update.begin(), nodes_to_update.end(), new_node_id);
if (new_node_iter != nodes_to_update.end()) {
cur_node_chosen = false;
} else {
cur_node_chosen = true;
nodes_to_update.push_back(new_node_id);
}
nodes_to_update.push_back(selected_neighbor);
// Sorting by id gives the ascending lock order (lower id first, as in
// mutuallyConnectNewElement) and allows the binary_search lookups below.
std::sort(nodes_to_update.begin(), nodes_to_update.end());
size_t nodes_to_update_count = nodes_to_update.size();
for (size_t i = 0; i < nodes_to_update_count; i++) {
lockNodeLinks(nodes_to_update[i]);
}
// Compact the neighbor's link list in place: `neighbour_neighbours_idx` is the write position
// for the links that are kept, `i` is the read position.
size_t neighbour_neighbours_idx = 0;
bool update_cur_node_required = true;
for (size_t i = 0; i < neighbor_level.getNumLinks(); i++) {
if (!std::binary_search(nodes_to_update.begin(), nodes_to_update.end(),
neighbor_level.getLinkAtPos(i))) {
// The neighbor is not in the "to_update" nodes list - leave it as is.
neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, neighbor_level.getLinkAtPos(i));
continue;
}
if (neighbor_level.getLinkAtPos(i) == new_node_id) {
// The new node is already among the neighbor's neighbours - this means another thread
// made an update between the moment we released the locks and the moment we
// re-acquired them - leave it as is, and do not connect the new node again below.
neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, neighbor_level.getLinkAtPos(i));
update_cur_node_required = false;
continue;
}
// Now we know that we are looking at a node to be removed from the neighbor's neighbors.
mutuallyRemoveNeighborAtPos(neighbor_level, level, selected_neighbor, i);
}
// Connect the new node to the neighbor only if it is still needed, the new node has a free
// slot, and neither of the two nodes was marked deleted in the meantime.
if (update_cur_node_required && new_node_level.getNumLinks() < max_M_cur &&
!isMarkedDeleted(new_node_id) && !isMarkedDeleted(selected_neighbor)) {
// update the connection between the new node and the neighbor.
new_node_level.appendLink(selected_neighbor);
if (cur_node_chosen && neighbour_neighbours_idx < max_M_cur) {
// connection is mutual - both new node and the selected neighbor in each other's list.
neighbor_level.setLinkAtPos(neighbour_neighbours_idx++, new_node_id);
} else {
// unidirectional connection - put the new node in the neighbour's incoming edges.
neighbor_level.newIncomingUnidirectionalEdge(new_node_id);
}
}
// Done updating the neighbor's neighbors - the write position is the new number of links.
neighbor_level.setNumLinks(neighbour_neighbours_idx);
for (size_t i = 0; i < nodes_to_update_count; i++) {
unlockNodeLinks(nodes_to_update[i]);
}
}
// Connects a newly inserted node to the neighbors selected for it at `level`.
//
// new_node_id    - id of the node being inserted; its link list at `level` is expected to be
//                  empty on entry (asserted).
// top_candidates - candidate (distance, id) pairs collected for the new node at this level.
// level          - graph level to connect at; picks the link cap (M0 at level 0, M otherwise).
//
// Returns the id of the candidate closest to the new node (the "next closest entry point").
//
// Locking: for every selected neighbor the link locks of the new node and of the neighbor are
// taken with the lower id first, and released before moving on to the next neighbor - either
// directly here or inside revisitNeighborConnections.
template <typename DataType, typename DistType>
idType HNSWIndex<DataType, DistType>::mutuallyConnectNewElement(
idType new_node_id, candidatesMaxHeap<DistType> &top_candidates, size_t level) {
// The maximum number of neighbors allowed for an existing neighbor (not new).
size_t max_M_cur = level ? M : M0;
// Filter the top candidates to the selected neighbors by the algorithm heuristics.
// First, we need to copy the top candidates to a vector.
candidatesList<DistType> top_candidates_list(this->allocator);
top_candidates_list.insert(top_candidates_list.end(), top_candidates.begin(),
top_candidates.end());
// Use the heuristic to filter the top candidates, and get the next closest entry point.
// Note that the selection is capped by M at every level (not by max_M_cur).
idType next_closest_entry_point = getNeighborsByHeuristic2(top_candidates_list, M);
assert(top_candidates_list.size() <= M &&
"Should be not be more than M candidates returned by the heuristic");
auto *new_node_level = getGraphDataByInternalId(new_node_id);
ElementLevelData &new_node_level_data = getElementLevelData(new_node_level, level);
assert(new_node_level_data.getNumLinks() == 0 &&
"The newly inserted element should have blank link list");
for (auto &neighbor_data : top_candidates_list) {
idType selected_neighbor = neighbor_data.second; // neighbor's id
auto *neighbor_graph_data = getGraphDataByInternalId(selected_neighbor);
// Lock both nodes, always taking the lock of the node with the lower id first.
if (new_node_id < selected_neighbor) {
lockNodeLinks(new_node_level);
lockNodeLinks(neighbor_graph_data);
} else {
lockNodeLinks(neighbor_graph_data);
lockNodeLinks(new_node_level);
}
// validations...
assert(new_node_level_data.getNumLinks() <= max_M_cur && "Neighbors number exceeds limit");
assert(selected_neighbor != new_node_id && "Trying to connect an element to itself");
// Revalidate the updated count - this may change between iterations due to releasing the
// lock.
if (new_node_level_data.getNumLinks() == max_M_cur) {
// The new node cannot add more neighbors - give up on the remaining candidates.
this->log(VecSimCommonStrings::LOG_DEBUG_STRING,
"Couldn't add all chosen neighbors upon inserting a new node");
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
break;
}
// If one of the two nodes has already been deleted - skip the operation.
if (isMarkedDeleted(new_node_id) || isMarkedDeleted(selected_neighbor)) {
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
continue;
}
ElementLevelData &neighbor_level_data = getElementLevelData(neighbor_graph_data, level);
// if the neighbor's neighbors list has the capacity to add the new node, make the update
// (a link in each direction) and finish with this neighbor.
if (neighbor_level_data.getNumLinks() < max_M_cur) {
new_node_level_data.appendLink(selected_neighbor);
neighbor_level_data.appendLink(new_node_id);
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
continue;
}
// Otherwise - we need to re-evaluate the neighbor's neighbors.
// We collect all the existing neighbors and the new node as candidates, and mutually update
// the neighbor's neighbors. We also release the acquired locks inside this call.
revisitNeighborConnections(level, new_node_id, neighbor_data, new_node_level_data,
neighbor_level_data);
}
return next_closest_entry_point;
}
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::repairConnectionsForDeletion(
idType element_internal_id, idType neighbour_id, ElementLevelData &node_level,
ElementLevelData &neighbor_level, size_t level, vecsim_stl::vector<bool> &neighbours_bitmap) {
if (isMarkedDeleted(neighbour_id)) {
// Just remove the deleted element from the neighbor's neighbors list. No need to repair as
// this change is temporary, this neighbor is about to be removed from the graph as well.
neighbor_level.removeLink(element_internal_id);
return;
}
// Add the deleted element's neighbour's original neighbors in the candidates.
vecsim_stl::vector<idType> candidate_ids(this->allocator);
candidate_ids.reserve(node_level.getNumLinks() + neighbor_level.getNumLinks());
vecsim_stl::vector<bool> neighbour_orig_neighbours_set(curElementCount, false, this->allocator);
for (size_t j = 0; j < neighbor_level.getNumLinks(); j++) {
idType cand = neighbor_level.getLinkAtPos(j);
neighbour_orig_neighbours_set[cand] = true;
// Don't add the removed element to the candidates, nor nodes that are neighbors of the
// original deleted element and will also be added to the candidates set.
if (cand != element_internal_id && !neighbours_bitmap[cand]) {
candidate_ids.push_back(cand);
}
}
// Put the deleted element's neighbours in the candidates.
for (size_t j = 0; j < node_level.getNumLinks(); j++) {
// Don't put the neighbor itself in his own candidates and nor marked deleted elements that
// were not neighbors before.
idType cand = node_level.getLinkAtPos(j);
if (cand != neighbour_id &&
(!isMarkedDeleted(cand) || neighbour_orig_neighbours_set[cand])) {
candidate_ids.push_back(cand);
}
}
size_t Mcurmax = level ? M : M0;
if (candidate_ids.size() > Mcurmax) {
// We need to filter the candidates by the heuristic.
candidatesList<DistType> candidates(this->allocator);
candidates.reserve(candidate_ids.size());
auto neighbours_data = getDataByInternalId(neighbour_id);
for (auto candidate_id : candidate_ids) {
candidates.emplace_back(
this->calcDistance(getDataByInternalId(candidate_id), neighbours_data),