59
59
import java .util .function .Function ;
60
60
import java .util .stream .Collectors ;
61
61
62
+ import static com .apple .foundationdb .async .MoreAsyncUtil .forEach ;
63
+ import static com .apple .foundationdb .async .MoreAsyncUtil .forLoop ;
64
+
62
65
/**
63
66
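* TODO: document this class. As far as the visible code shows, it maintains a layered graph of vector nodes that is reached through a single stored entry node reference, and it offers single-item ({@code insert}) and batched ({@code insertBatch}) insertion.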
64
67
*/
@@ -70,6 +73,7 @@ public class HNSW {
70
73
71
74
// Passed as the numeric argument of forEach(...) when the nodes behind a list of node references are fetched.
// NOTE(review): presumably the cap on concurrently outstanding node reads -- confirm against MoreAsyncUtil.forEach.
public static final int MAX_CONCURRENT_NODE_READS = 16 ;
72
75
// NOTE(review): no use of this constant is visible in the reviewed portion; the name suggests a cap on
// concurrent neighbor fetches -- confirm at the use site.
public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3 ;
76
// Passed as the numeric argument of forEach(...) in insertBatch for the per-item search-entry look-ups.
+ public static final int MAX_CONCURRENT_SEARCHES = 10 ;
73
77
// Shared Random created with the fixed seed 0, so the sequence it produces is the same in every JVM run.
// Because the instance is shared, every consumer advances the same sequence.
@ Nonnull public static final Random DEFAULT_RANDOM = new Random (0L );
74
78
// Default metric: an instance of Metric.EuclideanMetric.
@ Nonnull public static final Metric DEFAULT_METRIC = new Metric .EuclideanMetric ();
75
79
// NOTE(review): presumably the default for the HNSW "M" parameter (neighbors per node) -- confirm where it is read.
public static final int DEFAULT_M = 16 ;
@@ -697,12 +701,17 @@ private <R extends NodeReference, N extends NodeReference, U> CompletableFuture<
697
701
@ Nonnull final Iterable <R > nodeReferences ,
698
702
@ Nonnull final Function <R , U > fetchBypassFunction ,
699
703
@ Nonnull final BiFunction <R , Node <N >, U > biMapFunction ) {
700
- return MoreAsyncUtil . forEach (nodeReferences ,
704
+ return forEach (nodeReferences ,
701
705
currentNeighborReference -> fetchNodeIfNecessaryAndApply (storageAdapter , readTransaction , layer ,
702
706
currentNeighborReference , fetchBypassFunction , biMapFunction ), MAX_CONCURRENT_NODE_READS ,
703
707
getExecutor ());
704
708
}
705
709
710
+ @ Nonnull
711
+ public CompletableFuture <Void > insert (@ Nonnull final Transaction transaction , @ Nonnull final NodeReferenceWithVector nodeReferenceWithVector ) {
712
+ return insert (transaction , nodeReferenceWithVector .getPrimaryKey (), nodeReferenceWithVector .getVector ());
713
+ }
714
+
706
715
@ Nonnull
707
716
public CompletableFuture <Void > insert (@ Nonnull final Transaction transaction , @ Nonnull final Tuple newPrimaryKey ,
708
717
@ Nonnull final Vector <Half > newVector ) {
@@ -720,9 +729,9 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
720
729
new EntryNodeReference (newPrimaryKey , newVector , insertionLayer ), getOnWriteListener ());
721
730
debug (l -> l .debug ("written entry node reference with key={} on layer={}" , newPrimaryKey , insertionLayer ));
722
731
} else {
723
- final int entryNodeLayer = entryNodeReference .getLayer ();
724
- if (insertionLayer > entryNodeLayer ) {
725
- writeLonelyNodes (transaction , newPrimaryKey , newVector , insertionLayer , entryNodeLayer );
732
+ final int lMax = entryNodeReference .getLayer ();
733
+ if (insertionLayer > lMax ) {
734
+ writeLonelyNodes (transaction , newPrimaryKey , newVector , insertionLayer , lMax );
726
735
StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
727
736
new EntryNodeReference (newPrimaryKey , newVector , insertionLayer ), getOnWriteListener ());
728
737
debug (l -> l .debug ("written entry node reference with key={} on layer={}" , newPrimaryKey , insertionLayer ));
@@ -757,13 +766,104 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
757
766
}
758
767
759
768
@ Nonnull
760
- private CompletableFuture <Void > insertIntoLayers (final @ Nonnull Transaction transaction ,
761
- final @ Nonnull Tuple newPrimaryKey ,
762
- final @ Nonnull Vector <Half > newVector ,
763
- final NodeReferenceWithDistance nodeReference , final int lMax , final int insertionLayer ) {
764
- debug (l -> {
765
- l .debug ("nearest entry point at lMax={} is at key={}" , lMax , nodeReference .getPrimaryKey ());
766
- });
769
+ public CompletableFuture <Void > insertBatch (@ Nonnull final Transaction transaction ,
770
+ @ Nonnull List <NodeReferenceWithVector > batch ) {
771
+ final Metric metric = getConfig ().getMetric ();
772
+
773
+ // determine the layer each item should be inserted at
774
+ final Random random = getConfig ().getRandom ();
775
+ final List <NodeReferenceWithLayer > batchWithLayers = Lists .newArrayListWithCapacity (batch .size ());
776
+ for (final NodeReferenceWithVector current : batch ) {
777
+ batchWithLayers .add (new NodeReferenceWithLayer (current .getPrimaryKey (), current .getVector (),
778
+ insertionLayer (random )));
779
+ }
780
+ // sort the layers in reverse order
781
+ batchWithLayers .sort (Comparator .comparing (NodeReferenceWithLayer ::getL ).reversed ());
782
+
783
+ return StorageAdapter .fetchEntryNodeReference (transaction , getSubspace (), getOnReadListener ())
784
+ .thenCompose (entryNodeReference -> {
785
+ final int lMax = entryNodeReference == null ? -1 : entryNodeReference .getLayer ();
786
+
787
+ return forEach (batchWithLayers ,
788
+ item -> {
789
+ if (lMax == -1 ) {
790
+ return CompletableFuture .completedFuture (null );
791
+ }
792
+
793
+ final Vector <Half > itemVector = item .getVector ();
794
+ final int itemL = item .getL ();
795
+
796
+ final NodeReferenceWithDistance initialNodeReference =
797
+ new NodeReferenceWithDistance (entryNodeReference .getPrimaryKey (),
798
+ entryNodeReference .getVector (),
799
+ Vector .comparativeDistance (metric , entryNodeReference .getVector (), itemVector ));
800
+
801
+ return MoreAsyncUtil .forLoop (lMax , initialNodeReference ,
802
+ layer -> layer > itemL ,
803
+ layer -> layer - 1 ,
804
+ (layer , previousNodeReference ) -> {
805
+ final StorageAdapter <? extends NodeReference > storageAdapter = getStorageAdapterForLayer (layer );
806
+ return greedySearchLayer (storageAdapter , transaction ,
807
+ previousNodeReference , layer , itemVector );
808
+ }, executor );
809
+ }, MAX_CONCURRENT_SEARCHES , getExecutor ())
810
+ .thenCompose (searchEntryReferences ->
811
+ forLoop (0 , entryNodeReference ,
812
+ index -> index < batchWithLayers .size (),
813
+ index -> index + 1 ,
814
+ (index , currentEntryNodeReference ) -> {
815
+ final NodeReferenceWithLayer item = batchWithLayers .get (index );
816
+ final Tuple itemPrimaryKey = item .getPrimaryKey ();
817
+ final Vector <Half > itemVector = item .getVector ();
818
+ final int itemL = item .getL ();
819
+
820
+ final EntryNodeReference newEntryNodeReference ;
821
+ final int currentLMax ;
822
+
823
+ if (entryNodeReference == null ) {
824
+ // this is the first node
825
+ writeLonelyNodes (transaction , itemPrimaryKey , itemVector , itemL , -1 );
826
+ newEntryNodeReference =
827
+ new EntryNodeReference (itemPrimaryKey , itemVector , itemL );
828
+ StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
829
+ newEntryNodeReference , getOnWriteListener ());
830
+ debug (l -> l .debug ("written entry node reference with key={} on layer={}" , itemPrimaryKey , itemL ));
831
+
832
+ return CompletableFuture .completedFuture (newEntryNodeReference );
833
+ } else {
834
+ currentLMax = currentEntryNodeReference .getLayer ();
835
+ if (itemL > currentLMax ) {
836
+ writeLonelyNodes (transaction , itemPrimaryKey , itemVector , itemL , lMax );
837
+ newEntryNodeReference =
838
+ new EntryNodeReference (itemPrimaryKey , itemVector , itemL );
839
+ StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
840
+ newEntryNodeReference , getOnWriteListener ());
841
+ debug (l -> l .debug ("written entry node reference with key={} on layer={}" , itemPrimaryKey , itemL ));
842
+ } else {
843
+ newEntryNodeReference = entryNodeReference ;
844
+ }
845
+ }
846
+
847
+ debug (l -> l .debug ("entry node with key {} at layer {}" ,
848
+ currentEntryNodeReference .getPrimaryKey (), currentLMax ));
849
+
850
+ final var currentSearchEntry =
851
+ searchEntryReferences .get (index );
852
+
853
+ return insertIntoLayers (transaction , itemPrimaryKey , itemVector , currentSearchEntry ,
854
+ lMax , itemL ).thenApply (ignored -> newEntryNodeReference );
855
+ }, getExecutor ()));
856
+ }).thenCompose (ignored -> AsyncUtil .DONE );
857
+ }
858
+
859
+ @ Nonnull
860
+ private CompletableFuture <Void > insertIntoLayers (@ Nonnull final Transaction transaction ,
861
+ @ Nonnull final Tuple newPrimaryKey ,
862
+ @ Nonnull final Vector <Half > newVector ,
863
+ @ Nonnull final NodeReferenceWithDistance nodeReference ,
864
+ final int lMax ,
865
+ final int insertionLayer ) {
866
+ debug (l -> l .debug ("nearest entry point at lMax={} is at key={}" , lMax , nodeReference .getPrimaryKey ()));
767
867
return MoreAsyncUtil .<List <NodeReferenceWithDistance >>forLoop (Math .min (lMax , insertionLayer ), ImmutableList .of (nodeReference ),
768
868
layer -> layer >= 0 ,
769
869
layer -> layer - 1 ,
@@ -817,7 +917,7 @@ private <N extends NodeReference> CompletableFuture<List<NodeReferenceWithDistan
817
917
}
818
918
819
919
final int currentMMax = layer == 0 ? getConfig ().getMMax0 () : getConfig ().getMMax ();
820
- return MoreAsyncUtil . forEach (selectedNeighbors ,
920
+ return forEach (selectedNeighbors ,
821
921
selectedNeighbor -> {
822
922
final Node <N > selectedNeighborNode = selectedNeighbor .getNode ();
823
923
final NeighborsChangeSet <N > changeSet =
@@ -1110,4 +1210,43 @@ private void debug(@Nonnull final Consumer<Logger> loggerConsumer) {
1110
1210
loggerConsumer .accept (logger );
1111
1211
}
1112
1212
}
1213
+
1214
+ private static class NodeReferenceWithLayer extends NodeReferenceWithVector {
1215
+ @ SuppressWarnings ("checkstyle:MemberName" )
1216
+ private final int l ;
1217
+
1218
+ public NodeReferenceWithLayer (@ Nonnull final Tuple primaryKey , @ Nonnull final Vector <Half > vector ,
1219
+ final int l ) {
1220
+ super (primaryKey , vector );
1221
+ this .l = l ;
1222
+ }
1223
+
1224
+ public int getL () {
1225
+ return l ;
1226
+ }
1227
+ }
1228
+
1229
+ private static class NodeReferenceWithSearchEntry extends NodeReferenceWithVector {
1230
+ @ SuppressWarnings ("checkstyle:MemberName" )
1231
+ private final int l ;
1232
+ @ Nonnull
1233
+ private final NodeReferenceWithDistance nodeReferenceWithDistance ;
1234
+
1235
+ public NodeReferenceWithSearchEntry (@ Nonnull final Tuple primaryKey , @ Nonnull final Vector <Half > vector ,
1236
+ final int l ,
1237
+ @ Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance ) {
1238
+ super (primaryKey , vector );
1239
+ this .l = l ;
1240
+ this .nodeReferenceWithDistance = nodeReferenceWithDistance ;
1241
+ }
1242
+
1243
+ public int getL () {
1244
+ return l ;
1245
+ }
1246
+
1247
+ @ Nonnull
1248
+ public NodeReferenceWithDistance getNodeReferenceWithDistance () {
1249
+ return nodeReferenceWithDistance ;
1250
+ }
1251
+ }
1113
1252
}
0 commit comments