Skip to content

Commit 76a828d

Browse files
committed
batch insert
1 parent 6f4c73f commit 76a828d

File tree

3 files changed

+275
-37
lines changed

3 files changed

+275
-37
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,18 @@
2323
import com.apple.foundationdb.annotation.API;
2424
import com.apple.foundationdb.util.LoggableException;
2525
import com.google.common.base.Suppliers;
26-
import com.google.common.collect.ImmutableList;
2726
import com.google.common.collect.Lists;
2827
import com.google.common.util.concurrent.ThreadFactoryBuilder;
2928

3029
import javax.annotation.Nonnull;
3130
import javax.annotation.Nullable;
3231
import java.util.ArrayDeque;
3332
import java.util.ArrayList;
33+
import java.util.Arrays;
3434
import java.util.Collections;
3535
import java.util.Iterator;
3636
import java.util.List;
3737
import java.util.NoSuchElementException;
38-
import java.util.Objects;
3938
import java.util.Queue;
4039
import java.util.concurrent.CompletableFuture;
4140
import java.util.concurrent.CompletionException;
@@ -1105,23 +1104,14 @@ public static <T, U> CompletableFuture<List<U>> forEach(@Nonnull final Iterable<
11051104

11061105
final int index = indexAtomic.getAndIncrement();
11071106
working.add(body.apply(currentItem)
1108-
.thenAccept(resultNode -> {
1109-
Objects.requireNonNull(resultNode);
1110-
resultArray[index] = resultNode;
1111-
}));
1107+
.thenAccept(result -> resultArray[index] = result));
11121108
}
11131109

11141110
if (working.isEmpty()) {
11151111
return AsyncUtil.READY_FALSE;
11161112
}
11171113
return AsyncUtil.whenAny(working).thenApply(ignored -> true);
1118-
}, executor).thenApply(ignored -> {
1119-
final ImmutableList.Builder<U> resultBuilder = ImmutableList.builder();
1120-
for (final Object o : resultArray) {
1121-
resultBuilder.add((U)o);
1122-
}
1123-
return resultBuilder.build();
1124-
});
1114+
}, executor).thenApply(ignored -> Arrays.asList((U[])resultArray));
11251115
}
11261116

11271117
/**

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java

Lines changed: 151 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
import java.util.function.Function;
6060
import java.util.stream.Collectors;
6161

62+
import static com.apple.foundationdb.async.MoreAsyncUtil.forEach;
63+
import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop;
64+
6265
/**
6366
* TODO.
6467
*/
@@ -70,6 +73,7 @@ public class HNSW {
7073

7174
public static final int MAX_CONCURRENT_NODE_READS = 16;
7275
public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3;
76+
public static final int MAX_CONCURRENT_SEARCHES = 10;
7377
@Nonnull public static final Random DEFAULT_RANDOM = new Random(0L);
7478
@Nonnull public static final Metric DEFAULT_METRIC = new Metric.EuclideanMetric();
7579
public static final int DEFAULT_M = 16;
@@ -697,12 +701,17 @@ private <R extends NodeReference, N extends NodeReference, U> CompletableFuture<
697701
@Nonnull final Iterable<R> nodeReferences,
698702
@Nonnull final Function<R, U> fetchBypassFunction,
699703
@Nonnull final BiFunction<R, Node<N>, U> biMapFunction) {
700-
return MoreAsyncUtil.forEach(nodeReferences,
704+
return forEach(nodeReferences,
701705
currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer,
702706
currentNeighborReference, fetchBypassFunction, biMapFunction), MAX_CONCURRENT_NODE_READS,
703707
getExecutor());
704708
}
705709

710+
@Nonnull
711+
public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @Nonnull final NodeReferenceWithVector nodeReferenceWithVector) {
712+
return insert(transaction, nodeReferenceWithVector.getPrimaryKey(), nodeReferenceWithVector.getVector());
713+
}
714+
706715
@Nonnull
707716
public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey,
708717
@Nonnull final Vector<Half> newVector) {
@@ -720,9 +729,9 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
720729
new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener());
721730
debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer));
722731
} else {
723-
final int entryNodeLayer = entryNodeReference.getLayer();
724-
if (insertionLayer > entryNodeLayer) {
725-
writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, entryNodeLayer);
732+
final int lMax = entryNodeReference.getLayer();
733+
if (insertionLayer > lMax) {
734+
writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, lMax);
726735
StorageAdapter.writeEntryNodeReference(transaction, getSubspace(),
727736
new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener());
728737
debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer));
@@ -757,13 +766,104 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
757766
}
758767

759768
@Nonnull
760-
private CompletableFuture<Void> insertIntoLayers(final @Nonnull Transaction transaction,
761-
final @Nonnull Tuple newPrimaryKey,
762-
final @Nonnull Vector<Half> newVector,
763-
final NodeReferenceWithDistance nodeReference, final int lMax, final int insertionLayer) {
764-
debug(l -> {
765-
l.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey());
766-
});
769+
public CompletableFuture<Void> insertBatch(@Nonnull final Transaction transaction,
770+
@Nonnull List<NodeReferenceWithVector> batch) {
771+
final Metric metric = getConfig().getMetric();
772+
773+
// determine the layer each item should be inserted at
774+
final Random random = getConfig().getRandom();
775+
final List<NodeReferenceWithLayer> batchWithLayers = Lists.newArrayListWithCapacity(batch.size());
776+
for (final NodeReferenceWithVector current : batch) {
777+
batchWithLayers.add(new NodeReferenceWithLayer(current.getPrimaryKey(), current.getVector(),
778+
insertionLayer(random)));
779+
}
780+
// sort the layers in reverse order
781+
batchWithLayers.sort(Comparator.comparing(NodeReferenceWithLayer::getL).reversed());
782+
783+
return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener())
784+
.thenCompose(entryNodeReference -> {
785+
final int lMax = entryNodeReference == null ? -1 : entryNodeReference.getLayer();
786+
787+
return forEach(batchWithLayers,
788+
item -> {
789+
if (lMax == -1) {
790+
return CompletableFuture.completedFuture(null);
791+
}
792+
793+
final Vector<Half> itemVector = item.getVector();
794+
final int itemL = item.getL();
795+
796+
final NodeReferenceWithDistance initialNodeReference =
797+
new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(),
798+
entryNodeReference.getVector(),
799+
Vector.comparativeDistance(metric, entryNodeReference.getVector(), itemVector));
800+
801+
return MoreAsyncUtil.forLoop(lMax, initialNodeReference,
802+
layer -> layer > itemL,
803+
layer -> layer - 1,
804+
(layer, previousNodeReference) -> {
805+
final StorageAdapter<? extends NodeReference> storageAdapter = getStorageAdapterForLayer(layer);
806+
return greedySearchLayer(storageAdapter, transaction,
807+
previousNodeReference, layer, itemVector);
808+
}, executor);
809+
}, MAX_CONCURRENT_SEARCHES, getExecutor())
810+
.thenCompose(searchEntryReferences ->
811+
forLoop(0, entryNodeReference,
812+
index -> index < batchWithLayers.size(),
813+
index -> index + 1,
814+
(index, currentEntryNodeReference) -> {
815+
final NodeReferenceWithLayer item = batchWithLayers.get(index);
816+
final Tuple itemPrimaryKey = item.getPrimaryKey();
817+
final Vector<Half> itemVector = item.getVector();
818+
final int itemL = item.getL();
819+
820+
final EntryNodeReference newEntryNodeReference;
821+
final int currentLMax;
822+
823+
if (entryNodeReference == null) {
824+
// this is the first node
825+
writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, -1);
826+
newEntryNodeReference =
827+
new EntryNodeReference(itemPrimaryKey, itemVector, itemL);
828+
StorageAdapter.writeEntryNodeReference(transaction, getSubspace(),
829+
newEntryNodeReference, getOnWriteListener());
830+
debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL));
831+
832+
return CompletableFuture.completedFuture(newEntryNodeReference);
833+
} else {
834+
currentLMax = currentEntryNodeReference.getLayer();
835+
if (itemL > currentLMax) {
836+
writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, lMax);
837+
newEntryNodeReference =
838+
new EntryNodeReference(itemPrimaryKey, itemVector, itemL);
839+
StorageAdapter.writeEntryNodeReference(transaction, getSubspace(),
840+
newEntryNodeReference, getOnWriteListener());
841+
debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL));
842+
} else {
843+
newEntryNodeReference = entryNodeReference;
844+
}
845+
}
846+
847+
debug(l -> l.debug("entry node with key {} at layer {}",
848+
currentEntryNodeReference.getPrimaryKey(), currentLMax));
849+
850+
final var currentSearchEntry =
851+
searchEntryReferences.get(index);
852+
853+
return insertIntoLayers(transaction, itemPrimaryKey, itemVector, currentSearchEntry,
854+
lMax, itemL).thenApply(ignored -> newEntryNodeReference);
855+
}, getExecutor()));
856+
}).thenCompose(ignored -> AsyncUtil.DONE);
857+
}
858+
859+
@Nonnull
860+
private CompletableFuture<Void> insertIntoLayers(@Nonnull final Transaction transaction,
861+
@Nonnull final Tuple newPrimaryKey,
862+
@Nonnull final Vector<Half> newVector,
863+
@Nonnull final NodeReferenceWithDistance nodeReference,
864+
final int lMax,
865+
final int insertionLayer) {
866+
debug(l -> l.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey()));
767867
return MoreAsyncUtil.<List<NodeReferenceWithDistance>>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference),
768868
layer -> layer >= 0,
769869
layer -> layer - 1,
@@ -817,7 +917,7 @@ private <N extends NodeReference> CompletableFuture<List<NodeReferenceWithDistan
817917
}
818918

819919
final int currentMMax = layer == 0 ? getConfig().getMMax0() : getConfig().getMMax();
820-
return MoreAsyncUtil.forEach(selectedNeighbors,
920+
return forEach(selectedNeighbors,
821921
selectedNeighbor -> {
822922
final Node<N> selectedNeighborNode = selectedNeighbor.getNode();
823923
final NeighborsChangeSet<N> changeSet =
@@ -1110,4 +1210,43 @@ private void debug(@Nonnull final Consumer<Logger> loggerConsumer) {
11101210
loggerConsumer.accept(logger);
11111211
}
11121212
}
1213+
1214+
private static class NodeReferenceWithLayer extends NodeReferenceWithVector {
1215+
@SuppressWarnings("checkstyle:MemberName")
1216+
private final int l;
1217+
1218+
public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector<Half> vector,
1219+
final int l) {
1220+
super(primaryKey, vector);
1221+
this.l = l;
1222+
}
1223+
1224+
public int getL() {
1225+
return l;
1226+
}
1227+
}
1228+
1229+
private static class NodeReferenceWithSearchEntry extends NodeReferenceWithVector {
1230+
@SuppressWarnings("checkstyle:MemberName")
1231+
private final int l;
1232+
@Nonnull
1233+
private final NodeReferenceWithDistance nodeReferenceWithDistance;
1234+
1235+
public NodeReferenceWithSearchEntry(@Nonnull final Tuple primaryKey, @Nonnull final Vector<Half> vector,
1236+
final int l,
1237+
@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance) {
1238+
super(primaryKey, vector);
1239+
this.l = l;
1240+
this.nodeReferenceWithDistance = nodeReferenceWithDistance;
1241+
}
1242+
1243+
public int getL() {
1244+
return l;
1245+
}
1246+
1247+
@Nonnull
1248+
public NodeReferenceWithDistance getNodeReferenceWithDistance() {
1249+
return nodeReferenceWithDistance;
1250+
}
1251+
}
11131252
}

0 commit comments

Comments
 (0)