diff --git a/fdb-extensions/fdb-extensions.gradle b/fdb-extensions/fdb-extensions.gradle index bf281a2314..45b2e09302 100644 --- a/fdb-extensions/fdb-extensions.gradle +++ b/fdb-extensions/fdb-extensions.gradle @@ -26,6 +26,7 @@ dependencies { } api(libs.fdbJava) implementation(libs.guava) + implementation(libs.half4j) implementation(libs.slf4j.api) compileOnly(libs.jsr305) diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java index 563dec11a6..64e6d6b732 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/MoreAsyncUtil.java @@ -23,12 +23,14 @@ import com.apple.foundationdb.annotation.API; import com.apple.foundationdb.util.LoggableException; import com.google.common.base.Suppliers; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.ThreadFactoryBuilder; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -42,9 +44,13 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.IntPredicate; +import java.util.function.IntUnaryOperator; import java.util.function.Predicate; import java.util.function.Supplier; @@ -1051,6 +1057,64 @@ public static CompletableFuture swallowException(@Nonnull CompletableFutur return result; } + @Nonnull + public static CompletableFuture forLoop(final int startI, @Nullable final U startU, + @Nonnull final IntPredicate conditionPredicate, + @Nonnull final IntUnaryOperator stepFunction, + @Nonnull final BiFunction> body, + @Nonnull final Executor executor) { + final AtomicInteger loopVariableAtomic = new AtomicInteger(startI); + final AtomicReference lastResultAtomic = new AtomicReference<>(startU); + return whileTrue(() -> { + final int loopVariable = loopVariableAtomic.get(); + if (!conditionPredicate.test(loopVariable)) { + return AsyncUtil.READY_FALSE; + } + return body.apply(loopVariable, lastResultAtomic.get()) + .thenApply(result -> { + loopVariableAtomic.set(stepFunction.applyAsInt(loopVariable)); + lastResultAtomic.set(result); + return true; + }); + }, executor).thenApply(ignored -> lastResultAtomic.get()); + } + + @SuppressWarnings("unchecked") + public static CompletableFuture> forEach(@Nonnull final Iterable items, + @Nonnull final Function> body, + final int parallelism, + @Nonnull final Executor executor) { + // this deque is only modified by once upon creation + final ArrayDeque toBeProcessed = new ArrayDeque<>(); + for (final T item : items) { + toBeProcessed.addLast(item); + } + + final List> working = Lists.newArrayList(); + final AtomicInteger indexAtomic = new AtomicInteger(0); + final Object[] resultArray = new Object[toBeProcessed.size()]; + + return whileTrue(() -> { + working.removeIf(CompletableFuture::isDone); + + while (working.size() <= parallelism) { + final T currentItem = toBeProcessed.pollFirst(); + if (currentItem == null) { + break; + } + + final int index = indexAtomic.getAndIncrement(); + working.add(body.apply(currentItem) + .thenAccept(result -> resultArray[index] = result)); + } + + if (working.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + return whenAny(working).thenApply(ignored -> true); + }, executor).thenApply(ignored -> Arrays.asList((U[])resultArray)); + } + /** * A {@code Boolean} function that is always true. * @param the type of the (ignored) argument to the function diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java new file mode 100644 index 0000000000..aa062e8700 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractNode.java @@ -0,0 +1,63 @@ +/* + * AbstractNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +/** + * TODO. + * @param node type class. + */ +abstract class AbstractNode implements Node { + @Nonnull + private final Tuple primaryKey; + + @Nonnull + private final List neighbors; + + protected AbstractNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + this.primaryKey = primaryKey; + this.neighbors = ImmutableList.copyOf(neighbors); + } + + @Nonnull + @Override + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nonnull + @Override + public List getNeighbors() { + return neighbors; + } + + @Nonnull + @Override + public N getNeighbor(final int index) { + return neighbors.get(index); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java new file mode 100644 index 0000000000..e3d0c943fc --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/AbstractStorageAdapter.java @@ -0,0 +1,144 @@ +/* + * AbstractStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * Implementations and attributes common to all concrete implementations of {@link StorageAdapter}. + */ +abstract class AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(AbstractStorageAdapter.class); + + @Nonnull + private final HNSW.Config config; + @Nonnull + private final NodeFactory nodeFactory; + @Nonnull + private final Subspace subspace; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + private final Subspace dataSubspace; + + protected AbstractStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + this.config = config; + this.nodeFactory = nodeFactory; + this.subspace = subspace; + this.onWriteListener = onWriteListener; + this.onReadListener = onReadListener; + this.dataSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_DATA)); + } + + @Override + @Nonnull + public HNSW.Config getConfig() { + return config; + } + + @Nonnull + @Override + public NodeFactory getNodeFactory() { + return nodeFactory; + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return getNodeFactory().getNodeKind(); + } + + @Override + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + @Override + @Nonnull + public Subspace getDataSubspace() { + return dataSubspace; + } + + @Override + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + @Override + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + @Nonnull + @Override + public CompletableFuture> fetchNode(@Nonnull final ReadTransaction readTransaction, + int layer, @Nonnull Tuple primaryKey) { + return fetchNodeInternal(readTransaction, layer, primaryKey).thenApply(this::checkNode); + } + + @Nonnull + protected abstract CompletableFuture> fetchNodeInternal(@Nonnull ReadTransaction readTransaction, + int layer, @Nonnull Tuple primaryKey); + + /** + * Method to perform basic invariant check(s) on a newly-fetched node. + * + * @param node the node to check + * was passed in + * + * @return the node that was passed in + */ + @Nullable + private Node checkNode(@Nullable final Node node) { + return node; + } + + @Override + public void writeNode(@Nonnull Transaction transaction, @Nonnull Node node, int layer, + @Nonnull NeighborsChangeSet changeSet) { + writeNodeInternal(transaction, node, layer, changeSet); + if (logger.isDebugEnabled()) { + logger.debug("written node with key={} at layer={}", node.getPrimaryKey(), layer); + } + } + + protected abstract void writeNodeInternal(@Nonnull Transaction transaction, @Nonnull Node node, int layer, + @Nonnull NeighborsChangeSet changeSet); + +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java new file mode 100644 index 0000000000..bb8271af39 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/BaseNeighborsChangeSet.java @@ -0,0 +1,61 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.function.Predicate; + +/** + * TODO. + */ +class BaseNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private final List neighbors; + + public BaseNeighborsChangeSet(@Nonnull final List neighbors) { + this.neighbors = ImmutableList.copyOf(neighbors); + } + + @Nullable + @Override + public BaseNeighborsChangeSet getParent() { + return null; + } + + @Nonnull + @Override + public List merge() { + return neighbors; + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, + @Nonnull final Predicate primaryKeyPredicate) { + // nothing to be written + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java new file mode 100644 index 0000000000..a6a28e778d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactNode.java @@ -0,0 +1,103 @@ +/* + * CompactNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * TODO. + */ +public class CompactNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public Node create(@Nonnull final Tuple primaryKey, @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new CompactNode(primaryKey, Objects.requireNonNull(vector), (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.COMPACT; + } + }; + + @Nonnull + private final Vector vector; + + public CompactNode(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + this.vector = vector; + } + + @Nonnull + @Override + public NodeReference getSelfReference(@Nullable final Vector vector) { + return new NodeReference(getPrimaryKey()); + } + + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.COMPACT; + } + + @Nonnull + public Vector getVector() { + return vector; + } + + @Nonnull + @Override + public CompactNode asCompactNode() { + return this; + } + + @Nonnull + @Override + public InliningNode asInliningNode() { + throw new IllegalStateException("this is not an inlining node"); + } + + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "C[primaryKey=" + getPrimaryKey() + + ";vector=" + vector + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java new file mode 100644 index 0000000000..c3a04f86a2 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/CompactStorageAdapter.java @@ -0,0 +1,177 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * TODO. + */ +class CompactStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(CompactStorageAdapter.class); + + public CompactStorageAdapter(@Nonnull final HNSW.Config config, @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + return this; + } + + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + throw new IllegalStateException("cannot call this method on a compact storage adapter"); + } + + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] keyBytes = getDataSubspace().pack(Tuple.from(layer, primaryKey)); + + return readTransaction.get(keyBytes) + .thenApply(valueBytes -> { + if (valueBytes == null) { + throw new IllegalStateException("cannot fetch node"); + } + return nodeFromRaw(layer, primaryKey, keyBytes, valueBytes); + }); + } + + @Nonnull + private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, + @Nonnull final byte[] keyBytes, @Nonnull final byte[] valueBytes) { + final Tuple nodeTuple = Tuple.fromBytes(valueBytes); + final Node node = nodeFromTuples(primaryKey, nodeTuple); + final OnReadListener onReadListener = getOnReadListener(); + onReadListener.onNodeRead(layer, node); + onReadListener.onKeyValueRead(layer, keyBytes, valueBytes); + return node; + } + + @Nonnull + private Node nodeFromTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple valueTuple) { + final NodeKind nodeKind = NodeKind.fromSerializedNodeKind((byte)valueTuple.getLong(0)); + Verify.verify(nodeKind == NodeKind.COMPACT); + + final Tuple vectorTuple; + final Tuple neighborsTuple; + + vectorTuple = valueTuple.getNestedTuple(1); + neighborsTuple = valueTuple.getNestedTuple(2); + return compactNodeFromTuples(primaryKey, vectorTuple, neighborsTuple); + } + + @Nonnull + private Node compactNodeFromTuples(@Nonnull final Tuple primaryKey, + @Nonnull final Tuple vectorTuple, + @Nonnull final Tuple neighborsTuple) { + final Vector vector = StorageAdapter.vectorFromTuple(vectorTuple); + final List nodeReferences = Lists.newArrayListWithExpectedSize(neighborsTuple.size()); + + for (int i = 0; i < neighborsTuple.size(); i ++) { + final Tuple neighborTuple = neighborsTuple.getNestedTuple(i); + nodeReferences.add(new NodeReference(neighborTuple)); + } + + return getNodeFactory().create(primaryKey, vector, nodeReferences); + } + + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final byte[] key = getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey())); + + final List nodeItems = Lists.newArrayListWithExpectedSize(3); + nodeItems.add(NodeKind.COMPACT.getSerialized()); + final CompactNode compactNode = node.asCompactNode(); + nodeItems.add(StorageAdapter.tupleFromVector(compactNode.getVector())); + + final Iterable neighbors = neighborsChangeSet.merge(); + + final List neighborItems = Lists.newArrayList(); + for (final NodeReference neighborReference : neighbors) { + neighborItems.add(neighborReference.getPrimaryKey()); + } + nodeItems.add(Tuple.fromList(neighborItems)); + + final Tuple nodeTuple = Tuple.fromList(nodeItems); + + final byte[] value = nodeTuple.pack(); + transaction.set(key, value); + getOnWriteListener().onNodeWritten(layer, node); + getOnWriteListener().onKeyValueWritten(layer, key, value); + + if (logger.isDebugEnabled()) { + logger.debug("written neighbors of primaryKey={}, oldSize={}, newSize={}", node.getPrimaryKey(), + node.getNeighbors().size(), neighborItems.size()); + } + } + + @Nonnull + @Override + public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, + @Nullable final Tuple lastPrimaryKey, int maxNumRead) { + final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer)); + final Range range = + lastPrimaryKey == null + ? Range.startsWith(layerPrefix) + : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))), + ByteArrayUtil.strinc(layerPrefix)); + final AsyncIterable itemsIterable = + readTransaction.getRange(range, maxNumRead, false, StreamingMode.ITERATOR); + + return AsyncUtil.mapIterable(itemsIterable, keyValue -> { + final byte[] key = keyValue.getKey(); + final byte[] value = keyValue.getValue(); + final Tuple primaryKey = getDataSubspace().unpack(key).getNestedTuple(1); + return nodeFromRaw(layer, primaryKey, key, value); + }); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java new file mode 100644 index 0000000000..e431561119 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/DeleteNeighborsChangeSet.java @@ -0,0 +1,83 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Set; +import java.util.function.Predicate; + +/** + * TODO. + */ +class DeleteNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(DeleteNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Set deletedNeighborsPrimaryKeys; + + public DeleteNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final Collection deletedNeighborsPrimaryKeys) { + this.parent = parent; + this.deletedNeighborsPrimaryKeys = ImmutableSet.copyOf(deletedNeighborsPrimaryKeys); + } + + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + @Nonnull + @Override + public Iterable merge() { + return Iterables.filter(getParent().merge(), + current -> !deletedNeighborsPrimaryKeys.contains(current.getPrimaryKey())); + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !deletedNeighborsPrimaryKeys.contains(tuple))); + + for (final Tuple deletedNeighborPrimaryKey : deletedNeighborsPrimaryKeys) { + if (tuplePredicate.test(deletedNeighborPrimaryKey)) { + storageAdapter.deleteNeighbor(transaction, layer, node.asInliningNode(), deletedNeighborPrimaryKey); + if (logger.isDebugEnabled()) { + logger.debug("deleted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + deletedNeighborPrimaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java new file mode 100644 index 0000000000..db81252e17 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/EntryNodeReference.java @@ -0,0 +1,56 @@ +/* + * NodeWithLayer.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import java.util.Objects; + +class EntryNodeReference extends NodeReferenceWithVector { + private final int layer; + + public EntryNodeReference(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + public int getLayer() { + return layer; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof EntryNodeReference)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((EntryNodeReference)o).layer; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java new file mode 100644 index 0000000000..fb177c9d77 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSW.java @@ -0,0 +1,1246 @@ +/* + * HNSW.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.MoreAsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.collect.Streams; +import com.google.common.collect.TreeMultimap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.apple.foundationdb.async.MoreAsyncUtil.forEach; +import static com.apple.foundationdb.async.MoreAsyncUtil.forLoop; + +/** + * TODO. + */ +@API(API.Status.EXPERIMENTAL) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSW { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(HNSW.class); + + public static final int MAX_CONCURRENT_NODE_READS = 16; + public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3; + public static final int MAX_CONCURRENT_SEARCHES = 10; + @Nonnull public static final Random DEFAULT_RANDOM = new Random(0L); + @Nonnull public static final Metric DEFAULT_METRIC = new Metric.EuclideanMetric(); + public static final int DEFAULT_M = 16; + public static final int DEFAULT_M_MAX = DEFAULT_M; + public static final int DEFAULT_M_MAX_0 = 2 * DEFAULT_M; + public static final int DEFAULT_EF_SEARCH = 100; + public static final int DEFAULT_EF_CONSTRUCTION = 200; + public static final boolean DEFAULT_EXTEND_CANDIDATES = false; + public static final boolean DEFAULT_KEEP_PRUNED_CONNECTIONS = false; + + @Nonnull + public static final Config DEFAULT_CONFIG = new Config(); + + @Nonnull + private final Subspace subspace; + @Nonnull + private final Executor executor; + @Nonnull + private final Config config; + @Nonnull + private final OnWriteListener onWriteListener; + @Nonnull + private final OnReadListener onReadListener; + + /** + * Configuration settings for a {@link HNSW}. + */ + @SuppressWarnings("checkstyle:MemberName") + public static class Config { + @Nonnull + private final Random random; + @Nonnull + private final Metric metric; + private final int m; + private final int mMax; + private final int mMax0; + private final int efSearch; + private final int efConstruction; + private final boolean extendCandidates; + private final boolean keepPrunedConnections; + + protected Config() { + this.random = DEFAULT_RANDOM; + this.metric = DEFAULT_METRIC; + this.m = DEFAULT_M; + this.mMax = DEFAULT_M_MAX; + this.mMax0 = DEFAULT_M_MAX_0; + this.efSearch = DEFAULT_EF_SEARCH; + this.efConstruction = DEFAULT_EF_CONSTRUCTION; + this.extendCandidates = DEFAULT_EXTEND_CANDIDATES; + this.keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; + } + + protected Config(@Nonnull final Random random, @Nonnull final Metric metric, final int m, final int mMax, + final int mMax0, final int efSearch, final int efConstruction, final boolean extendCandidates, + final boolean keepPrunedConnections) { + this.random = random; + this.metric = metric; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efSearch = efSearch; + this.efConstruction = efConstruction; + this.extendCandidates = extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + } + + @Nonnull + public Random getRandom() { + return random; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + public int getM() { + return m; + } + + public int getMMax() { + return mMax; + } + + public int getMMax0() { + return mMax0; + } + + public int getEfSearch() { + return efSearch; + } + + public int getEfConstruction() { + return efConstruction; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + public boolean isKeepPrunedConnections() { + return keepPrunedConnections; + } + + @Nonnull + public ConfigBuilder toBuilder() { + return new ConfigBuilder(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(), + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + } + + @Override + @Nonnull + public String toString() { + return "Config[metric=" + getMetric() + "M=" + getM() + " , MMax=" + getMMax() + " , MMax0=" + getMMax0() + + ", efSearch=" + getEfSearch() + ", efConstruction=" + getEfConstruction() + + ", isExtendCandidates=" + isExtendCandidates() + + ", isKeepPrunedConnections=" + isKeepPrunedConnections() + "]"; + } + } + + /** + * Builder for {@link Config}. + * + * @see #newConfigBuilder + */ + @CanIgnoreReturnValue + @SuppressWarnings("checkstyle:MemberName") + public static class ConfigBuilder { + @Nonnull + private Random random = DEFAULT_RANDOM; + @Nonnull + private Metric metric = DEFAULT_METRIC; + private int m = DEFAULT_M; + private int mMax = DEFAULT_M_MAX; + private int mMax0 = DEFAULT_M_MAX_0; + private int efSearch = DEFAULT_EF_SEARCH; + private int efConstruction = DEFAULT_EF_CONSTRUCTION; + private boolean extendCandidates = DEFAULT_EXTEND_CANDIDATES; + private boolean keepPrunedConnections = DEFAULT_KEEP_PRUNED_CONNECTIONS; + + public ConfigBuilder() { + } + + public ConfigBuilder(@Nonnull Random random, @Nonnull final Metric metric, final int m, final int mMax, + final int mMax0, final int efSearch, final int efConstruction, + final boolean extendCandidates, final boolean keepPrunedConnections) { + this.random = random; + this.metric = metric; + this.m = m; + this.mMax = mMax; + this.mMax0 = mMax0; + this.efSearch = efSearch; + this.efConstruction = efConstruction; + this.extendCandidates = extendCandidates; + this.keepPrunedConnections = keepPrunedConnections; + } + + @Nonnull + public Random getRandom() { + return random; + } + + @Nonnull + public ConfigBuilder setRandom(@Nonnull final Random random) { + this.random = random; + return this; + } + + @Nonnull + public Metric getMetric() { + return metric; + } + + @Nonnull + public ConfigBuilder setMetric(@Nonnull final Metric metric) { + this.metric = metric; + return this; + } + + public int getM() { + return m; + } + + @Nonnull + public ConfigBuilder setM(final int m) { + this.m = m; + return this; + } + + public int getMMax() { + return mMax; + } + + @Nonnull + public ConfigBuilder setMMax(final int mMax) { + this.mMax = mMax; + return this; + } + + public int getMMax0() { + return mMax0; + } + + @Nonnull + public ConfigBuilder setMMax0(final int mMax0) { + this.mMax0 = mMax0; + return this; + } + + public int getEfSearch() { + return efSearch; + } + + public ConfigBuilder setEfSearch(final int efSearch) { + this.efSearch = efSearch; + return this; + } + + public int getEfConstruction() { + return efConstruction; + } + + public ConfigBuilder setEfConstruction(final int efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + public boolean isExtendCandidates() { + return extendCandidates; + } + + public ConfigBuilder setExtendCandidates(final boolean extendCandidates) { + this.extendCandidates = extendCandidates; + return this; + } + + public boolean isKeepPrunedConnections() { + return keepPrunedConnections; + } + + public ConfigBuilder setKeepPrunedConnections(final boolean keepPrunedConnections) { + this.keepPrunedConnections = keepPrunedConnections; + return this; + } + + public Config build() { + return new Config(getRandom(), getMetric(), getM(), getMMax(), getMMax0(), getEfSearch(), + getEfConstruction(), isExtendCandidates(), isKeepPrunedConnections()); + } + } + + /** + * Start building a {@link Config}. + * @return a new {@code Config} that can be altered and then built for use with a {@link HNSW} + * @see ConfigBuilder#build + */ + public static ConfigBuilder newConfigBuilder() { + return new ConfigBuilder(); + } + + /** + * TODO. + */ + public HNSW(@Nonnull final Subspace subspace, @Nonnull final Executor executor) { + this(subspace, executor, DEFAULT_CONFIG, OnWriteListener.NOOP, OnReadListener.NOOP); + } + + /** + * TODO. + */ + public HNSW(@Nonnull final Subspace subspace, + @Nonnull final Executor executor, @Nonnull final Config config, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + this.subspace = subspace; + this.executor = executor; + this.config = config; + this.onWriteListener = onWriteListener; + this.onReadListener = onReadListener; + } + + + @Nonnull + public Subspace getSubspace() { + return subspace; + } + + /** + * Get the executer used by this r-tree. + * @return executor used when running asynchronous tasks + */ + @Nonnull + public Executor getExecutor() { + return executor; + } + + /** + * Get this r-tree's configuration. + * @return r-tree configuration + */ + @Nonnull + public Config getConfig() { + return config; + } + + /** + * Get the on-write listener. + * @return the on-write listener + */ + @Nonnull + public OnWriteListener getOnWriteListener() { + return onWriteListener; + } + + /** + * Get the on-read listener. + * @return the on-read listener + */ + @Nonnull + public OnReadListener getOnReadListener() { + return onReadListener; + } + + // + // Read Path + // + + /** + * TODO. + */ + @SuppressWarnings("checkstyle:MethodName") // method name introduced by paper + @Nonnull + public CompletableFuture>> kNearestNeighborsSearch(@Nonnull final ReadTransaction readTransaction, + final int k, + final int efSearch, + @Nonnull final Vector queryVector) { + return StorageAdapter.fetchEntryNodeReference(readTransaction, getSubspace(), getOnReadListener()) + .thenCompose(entryPointAndLayer -> { + if (entryPointAndLayer == null) { + return CompletableFuture.completedFuture(null); // not a single node in the index + } + + final Metric metric = getConfig().getMetric(); + + final NodeReferenceWithDistance entryState = + new NodeReferenceWithDistance(entryPointAndLayer.getPrimaryKey(), + entryPointAndLayer.getVector(), + Vector.comparativeDistance(metric, entryPointAndLayer.getVector(), queryVector)); + + final var entryLayer = entryPointAndLayer.getLayer(); + if (entryLayer == 0) { + // entry data points to a node in layer 0 directly + return CompletableFuture.completedFuture(entryState); + } + + return forLoop(entryLayer, entryState, + layer -> layer > 0, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final var storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, readTransaction, previousNodeReference, + layer, queryVector); + }, executor); + }).thenCompose(nodeReference -> { + if (nodeReference == null) { + return CompletableFuture.completedFuture(null); + } + + final var storageAdapter = getStorageAdapterForLayer(0); + + return searchLayer(storageAdapter, readTransaction, + ImmutableList.of(nodeReference), 0, efSearch, + Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> { + // reverse the original queue + final TreeMultimap> sortedTopK = + TreeMultimap.create(Comparator.naturalOrder(), + Comparator.comparing(nodeReferenceAndNode -> nodeReferenceAndNode.getNode().getPrimaryKey())); + + for (final NodeReferenceAndNode nodeReferenceAndNode : searchResult) { + if (sortedTopK.size() < k || sortedTopK.keySet().last() > + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance()) { + sortedTopK.put(nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance(), + nodeReferenceAndNode); + } + + if (sortedTopK.size() > k) { + final Double lastKey = sortedTopK.keySet().last(); + final NodeReferenceAndNode lastNode = sortedTopK.get(lastKey).last(); + sortedTopK.remove(lastKey, lastNode); + } + } + + return ImmutableList.copyOf(sortedTopK.values()); + }); + }); + } + + @Nonnull + private CompletableFuture greedySearchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + if (storageAdapter.getNodeKind() == NodeKind.INLINING) { + return greedySearchInliningLayer(storageAdapter.asInliningStorageAdapter(), readTransaction, entryNeighbor, layer, queryVector); + } else { + return searchLayer(storageAdapter, readTransaction, ImmutableList.of(entryNeighbor), layer, 1, Maps.newConcurrentMap(), queryVector) + .thenApply(searchResult -> Iterables.getOnlyElement(searchResult).getNodeReferenceWithDistance()); + } + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture greedySearchInliningLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final NodeReferenceWithDistance entryNeighbor, + final int layer, + @Nonnull final Vector queryVector) { + Verify.verify(layer > 0); + final Metric metric = getConfig().getMetric(); + final AtomicReference currentNodeReferenceAtomic = + new AtomicReference<>(entryNeighbor); + + return AsyncUtil.whileTrue(() -> onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, currentNodeReferenceAtomic.get().getPrimaryKey())) + .thenApply(node -> { + if (node == null) { + throw new IllegalStateException("unable to fetch node"); + } + final InliningNode inliningNode = node.asInliningNode(); + final List neighbors = inliningNode.getNeighbors(); + + final NodeReferenceWithDistance currentNodeReference = currentNodeReferenceAtomic.get(); + double minDistance = currentNodeReference.getDistance(); + + NodeReferenceWithVector nearestNeighbor = null; + for (final NodeReferenceWithVector neighbor : neighbors) { + final double distance = + Vector.comparativeDistance(metric, neighbor.getVector(), queryVector); + if (distance < minDistance) { + minDistance = distance; + nearestNeighbor = neighbor; + } + } + + if (nearestNeighbor == null) { + return false; + } + + currentNodeReferenceAtomic.set( + new NodeReferenceWithDistance(nearestNeighbor.getPrimaryKey(), nearestNeighbor.getVector(), + minDistance)); + return true; + }), executor).thenApply(ignored -> currentNodeReferenceAtomic.get()); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture>> searchLayer(@Nonnull StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Collection entryNeighbors, + final int layer, + final int efSearch, + @Nonnull final Map> nodeCache, + @Nonnull final Vector queryVector) { + final Set visited = Sets.newConcurrentHashSet(NodeReference.primaryKeys(entryNeighbors)); + final Queue candidates = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(entryNeighbors); + final Queue nearestNeighbors = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance).reversed()); + nearestNeighbors.addAll(entryNeighbors); + final Metric metric = getConfig().getMetric(); + + return AsyncUtil.whileTrue(() -> { + if (candidates.isEmpty()) { + return AsyncUtil.READY_FALSE; + } + + final NodeReferenceWithDistance candidate = candidates.poll(); + final NodeReferenceWithDistance furthestNeighbor = Objects.requireNonNull(nearestNeighbors.peek()); + + if (candidate.getDistance() > furthestNeighbor.getDistance()) { + return AsyncUtil.READY_FALSE; + } + + return fetchNodeIfNotCached(storageAdapter, readTransaction, layer, candidate, nodeCache) + .thenApply(candidateNode -> + Iterables.filter(candidateNode.getNeighbors(), + neighbor -> !visited.contains(neighbor.getPrimaryKey()))) + .thenCompose(neighborReferences -> fetchNeighborhood(storageAdapter, readTransaction, + layer, neighborReferences, nodeCache)) + .thenApply(neighborReferences -> { + for (final NodeReferenceWithVector current : neighborReferences) { + visited.add(current.getPrimaryKey()); + final double furthestDistance = + Objects.requireNonNull(nearestNeighbors.peek()).getDistance(); + + final double currentDistance = + Vector.comparativeDistance(metric, current.getVector(), queryVector); + if (currentDistance < furthestDistance || nearestNeighbors.size() < efSearch) { + final NodeReferenceWithDistance currentWithDistance = + new NodeReferenceWithDistance(current.getPrimaryKey(), current.getVector(), + currentDistance); + candidates.add(currentWithDistance); + nearestNeighbors.add(currentWithDistance); + if (nearestNeighbors.size() > efSearch) { + nearestNeighbors.poll(); + } + } + } + return true; + }); + }).thenCompose(ignored -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, nearestNeighbors, nodeCache)) + .thenApply(searchResult -> { + debug(l -> l.debug("searched layer={} for efSearch={} with result=={}", layer, efSearch, + searchResult.stream() + .map(nodeReferenceAndNode -> + "(primaryKey=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + nodeReferenceAndNode.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(",")))); + return searchResult; + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture> fetchNodeIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final NodeReference nodeReference, + @Nonnull final Map> nodeCache) { + return fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, nodeReference, + nR -> nodeCache.get(nR.getPrimaryKey()), + (nR, node) -> { + nodeCache.put(nR.getPrimaryKey(), node); + return node; + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture fetchNodeIfNecessaryAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final R nodeReference, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + final U bypass = fetchBypassFunction.apply(nodeReference); + if (bypass != null) { + return CompletableFuture.completedFuture(bypass); + } + + return onReadListener.onAsyncRead( + storageAdapter.fetchNode(readTransaction, layer, nodeReference.getPrimaryKey())) + .thenApply(node -> biMapFunction.apply(nodeReference, node)); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture> fetchNeighborhood(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable neighborReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, neighborReferences, + neighborReference -> { + if (neighborReference instanceof NodeReferenceWithVector) { + return (NodeReferenceWithVector)neighborReference; + } + final Node neighborNode = nodeCache.get(neighborReference.getPrimaryKey()); + if (neighborNode == null) { + return null; + } + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }, + (neighborReference, neighborNode) -> { + nodeCache.put(neighborReference.getPrimaryKey(), neighborNode); + return new NodeReferenceWithVector(neighborReference.getPrimaryKey(), neighborNode.asCompactNode().getVector()); + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture>> fetchSomeNodesIfNotCached(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Map> nodeCache) { + return fetchSomeNodesAndApply(storageAdapter, readTransaction, layer, nodeReferences, + nodeReference -> { + final Node node = nodeCache.get(nodeReference.getPrimaryKey()); + if (node == null) { + return null; + } + return new NodeReferenceAndNode<>(nodeReference, node); + }, + (nodeReferenceWithDistance, node) -> { + nodeCache.put(nodeReferenceWithDistance.getPrimaryKey(), node); + return new NodeReferenceAndNode<>(nodeReferenceWithDistance, node); + }); + } + + /** + * TODO. + */ + @Nonnull + private CompletableFuture> fetchSomeNodesAndApply(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Iterable nodeReferences, + @Nonnull final Function fetchBypassFunction, + @Nonnull final BiFunction, U> biMapFunction) { + return forEach(nodeReferences, + currentNeighborReference -> fetchNodeIfNecessaryAndApply(storageAdapter, readTransaction, layer, + currentNeighborReference, fetchBypassFunction, biMapFunction), MAX_CONCURRENT_NODE_READS, + getExecutor()); + } + + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final NodeReferenceWithVector nodeReferenceWithVector) { + return insert(transaction, nodeReferenceWithVector.getPrimaryKey(), nodeReferenceWithVector.getVector()); + } + + @Nonnull + public CompletableFuture insert(@Nonnull final Transaction transaction, @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + final Metric metric = getConfig().getMetric(); + + final int insertionLayer = insertionLayer(getConfig().getRandom()); + debug(l -> l.debug("new node with key={} selected to be inserted into layer={}", newPrimaryKey, insertionLayer)); + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenApply(entryNodeReference -> { + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, -1); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + } else { + final int lMax = entryNodeReference.getLayer(); + if (insertionLayer > lMax) { + writeLonelyNodes(transaction, newPrimaryKey, newVector, insertionLayer, lMax); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + new EntryNodeReference(newPrimaryKey, newVector, insertionLayer), getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", newPrimaryKey, insertionLayer)); + } + } + return entryNodeReference; + }).thenCompose(entryNodeReference -> { + if (entryNodeReference == null) { + return AsyncUtil.DONE; + } + + final int lMax = entryNodeReference.getLayer(); + debug(l -> l.debug("entry node with key {} at layer {}", entryNodeReference.getPrimaryKey(), + lMax)); + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), newVector)); + return forLoop(lMax, initialNodeReference, + layer -> layer > insertionLayer, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, newVector); + }, executor) + .thenCompose(nodeReference -> + insertIntoLayers(transaction, newPrimaryKey, newVector, nodeReference, + lMax, insertionLayer)); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + public CompletableFuture insertBatch(@Nonnull final Transaction transaction, + @Nonnull List batch) { + final Metric metric = getConfig().getMetric(); + + // determine the layer each item should be inserted at + final Random random = getConfig().getRandom(); + final List batchWithLayers = Lists.newArrayListWithCapacity(batch.size()); + for (final NodeReferenceWithVector current : batch) { + batchWithLayers.add(new NodeReferenceWithLayer(current.getPrimaryKey(), current.getVector(), + insertionLayer(random))); + } + // sort the layers in reverse order + batchWithLayers.sort(Comparator.comparing(NodeReferenceWithLayer::getLayer).reversed()); + + return StorageAdapter.fetchEntryNodeReference(transaction, getSubspace(), getOnReadListener()) + .thenCompose(entryNodeReference -> { + final int lMax = entryNodeReference == null ? -1 : entryNodeReference.getLayer(); + + return forEach(batchWithLayers, + item -> { + if (lMax == -1) { + return CompletableFuture.completedFuture(null); + } + + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final NodeReferenceWithDistance initialNodeReference = + new NodeReferenceWithDistance(entryNodeReference.getPrimaryKey(), + entryNodeReference.getVector(), + Vector.comparativeDistance(metric, entryNodeReference.getVector(), itemVector)); + + return forLoop(lMax, initialNodeReference, + layer -> layer > itemL, + layer -> layer - 1, + (layer, previousNodeReference) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return greedySearchLayer(storageAdapter, transaction, + previousNodeReference, layer, itemVector); + }, executor); + }, MAX_CONCURRENT_SEARCHES, getExecutor()) + .thenCompose(searchEntryReferences -> + forLoop(0, entryNodeReference, + index -> index < batchWithLayers.size(), + index -> index + 1, + (index, currentEntryNodeReference) -> { + final NodeReferenceWithLayer item = batchWithLayers.get(index); + final Tuple itemPrimaryKey = item.getPrimaryKey(); + final Vector itemVector = item.getVector(); + final int itemL = item.getLayer(); + + final EntryNodeReference newEntryNodeReference; + final int currentLMax; + + if (entryNodeReference == null) { + // this is the first node + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, -1); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + + return CompletableFuture.completedFuture(newEntryNodeReference); + } else { + currentLMax = currentEntryNodeReference.getLayer(); + if (itemL > currentLMax) { + writeLonelyNodes(transaction, itemPrimaryKey, itemVector, itemL, lMax); + newEntryNodeReference = + new EntryNodeReference(itemPrimaryKey, itemVector, itemL); + StorageAdapter.writeEntryNodeReference(transaction, getSubspace(), + newEntryNodeReference, getOnWriteListener()); + debug(l -> l.debug("written entry node reference with key={} on layer={}", itemPrimaryKey, itemL)); + } else { + newEntryNodeReference = entryNodeReference; + } + } + + debug(l -> l.debug("entry node with key {} at layer {}", + currentEntryNodeReference.getPrimaryKey(), currentLMax)); + + final var currentSearchEntry = + searchEntryReferences.get(index); + + return insertIntoLayers(transaction, itemPrimaryKey, itemVector, currentSearchEntry, + lMax, itemL).thenApply(ignored -> newEntryNodeReference); + }, getExecutor())); + }).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + private CompletableFuture insertIntoLayers(@Nonnull final Transaction transaction, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector, + @Nonnull final NodeReferenceWithDistance nodeReference, + final int lMax, + final int insertionLayer) { + debug(l -> l.debug("nearest entry point at lMax={} is at key={}", lMax, nodeReference.getPrimaryKey())); + return MoreAsyncUtil.>forLoop(Math.min(lMax, insertionLayer), ImmutableList.of(nodeReference), + layer -> layer >= 0, + layer -> layer - 1, + (layer, previousNodeReferences) -> { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + return insertIntoLayer(storageAdapter, transaction, + previousNodeReferences, layer, newPrimaryKey, newVector); + }, executor).thenCompose(ignored -> AsyncUtil.DONE); + } + + @Nonnull + private CompletableFuture> insertIntoLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final List nearestNeighbors, + int layer, + @Nonnull final Tuple newPrimaryKey, + @Nonnull final Vector newVector) { + debug(l -> l.debug("begin insert key={} at layer={}", newPrimaryKey, layer)); + final Map> nodeCache = Maps.newConcurrentMap(); + + return searchLayer(storageAdapter, transaction, + nearestNeighbors, layer, config.getEfConstruction(), nodeCache, newVector) + .thenCompose(searchResult -> { + final List references = NodeReferenceAndNode.getReferences(searchResult); + + return selectNeighbors(storageAdapter, transaction, searchResult, layer, getConfig().getM(), + getConfig().isExtendCandidates(), nodeCache, newVector) + .thenCompose(selectedNeighbors -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node newNode = + nodeFactory.create(newPrimaryKey, newVector, + NodeReferenceAndNode.getReferences(selectedNeighbors)); + + final NeighborsChangeSet newNodeChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + newNode.getNeighbors()); + + storageAdapter.writeNode(transaction, newNode, layer, newNodeChangeSet); + + // create change sets for each selected neighbor and insert new node into them + final Map> neighborChangeSetMap = + Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode selectedNeighbor : selectedNeighbors) { + final NeighborsChangeSet baseSet = + new BaseNeighborsChangeSet<>(selectedNeighbor.getNode().getNeighbors()); + final NeighborsChangeSet insertSet = + new InsertNeighborsChangeSet<>(baseSet, ImmutableList.of(newNode.getSelfReference(newVector))); + neighborChangeSetMap.put(selectedNeighbor.getNode().getPrimaryKey(), + insertSet); + } + + final int currentMMax = layer == 0 ? getConfig().getMMax0() : getConfig().getMMax(); + return forEach(selectedNeighbors, + selectedNeighbor -> { + final Node selectedNeighborNode = selectedNeighbor.getNode(); + final NeighborsChangeSet changeSet = + Objects.requireNonNull(neighborChangeSetMap.get(selectedNeighborNode.getPrimaryKey())); + return pruneNeighborsIfNecessary(storageAdapter, transaction, + selectedNeighbor, layer, currentMMax, changeSet, nodeCache) + .thenApply(nodeReferencesAndNodes -> { + if (nodeReferencesAndNodes == null) { + return changeSet; + } + return resolveChangeSetFromNewNeighbors(changeSet, nodeReferencesAndNodes); + }); + }, MAX_CONCURRENT_NEIGHBOR_FETCHES, getExecutor()) + .thenApply(changeSets -> { + for (int i = 0; i < selectedNeighbors.size(); i++) { + final NodeReferenceAndNode selectedNeighbor = selectedNeighbors.get(i); + final NeighborsChangeSet changeSet = changeSets.get(i); + storageAdapter.writeNode(transaction, selectedNeighbor.getNode(), + layer, changeSet); + } + return ImmutableList.copyOf(references); + }); + }); + }).thenApply(nodeReferencesWithDistances -> { + debug(l -> l.debug("end insert key={} at layer={}", newPrimaryKey, layer)); + return nodeReferencesWithDistances; + }); + } + + private NeighborsChangeSet resolveChangeSetFromNewNeighbors(@Nonnull final NeighborsChangeSet beforeChangeSet, + @Nonnull final Iterable> afterNeighbors) { + final Map beforeNeighborsMap = Maps.newLinkedHashMap(); + for (final N n : beforeChangeSet.merge()) { + beforeNeighborsMap.put(n.getPrimaryKey(), n); + } + + final Map afterNeighborsMap = Maps.newLinkedHashMap(); + for (final NodeReferenceAndNode nodeReferenceAndNode : afterNeighbors) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + + afterNeighborsMap.put(nodeReferenceWithDistance.getPrimaryKey(), + nodeReferenceAndNode.getNode().getSelfReference(nodeReferenceWithDistance.getVector())); + } + + final ImmutableList.Builder toBeDeletedBuilder = ImmutableList.builder(); + for (final Map.Entry beforeNeighborEntry : beforeNeighborsMap.entrySet()) { + if (!afterNeighborsMap.containsKey(beforeNeighborEntry.getKey())) { + toBeDeletedBuilder.add(beforeNeighborEntry.getValue().getPrimaryKey()); + } + } + final List toBeDeleted = toBeDeletedBuilder.build(); + + final ImmutableList.Builder toBeInsertedBuilder = ImmutableList.builder(); + for (final Map.Entry afterNeighborEntry : afterNeighborsMap.entrySet()) { + if (!beforeNeighborsMap.containsKey(afterNeighborEntry.getKey())) { + toBeInsertedBuilder.add(afterNeighborEntry.getValue()); + } + } + final List toBeInserted = toBeInsertedBuilder.build(); + + NeighborsChangeSet changeSet = beforeChangeSet; + + if (!toBeDeleted.isEmpty()) { + changeSet = new DeleteNeighborsChangeSet<>(changeSet, toBeDeleted); + } + if (!toBeInserted.isEmpty()) { + changeSet = new InsertNeighborsChangeSet<>(changeSet, toBeInserted); + } + return changeSet; + } + + @Nonnull + private CompletableFuture>> pruneNeighborsIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + @Nonnull final NodeReferenceAndNode selectedNeighbor, + int layer, + int mMax, + @Nonnull final NeighborsChangeSet neighborChangeSet, + @Nonnull final Map> nodeCache) { + final Metric metric = getConfig().getMetric(); + final Node selectedNeighborNode = selectedNeighbor.getNode(); + if (selectedNeighborNode.getNeighbors().size() < mMax) { + return CompletableFuture.completedFuture(null); + } else { + debug(l -> l.debug("pruning neighborhood of key={} which has numNeighbors={} out of mMax={}", + selectedNeighborNode.getPrimaryKey(), selectedNeighborNode.getNeighbors().size(), mMax)); + return fetchNeighborhood(storageAdapter, transaction, layer, neighborChangeSet.merge(), nodeCache) + .thenCompose(nodeReferenceWithVectors -> { + final ImmutableList.Builder nodeReferencesWithDistancesBuilder = + ImmutableList.builder(); + for (final NodeReferenceWithVector nodeReferenceWithVector : nodeReferenceWithVectors) { + final var vector = nodeReferenceWithVector.getVector(); + final double distance = + Vector.comparativeDistance(metric, vector, + selectedNeighbor.getNodeReferenceWithDistance().getVector()); + nodeReferencesWithDistancesBuilder.add( + new NodeReferenceWithDistance(nodeReferenceWithVector.getPrimaryKey(), + vector, distance)); + } + return fetchSomeNodesIfNotCached(storageAdapter, transaction, layer, + nodeReferencesWithDistancesBuilder.build(), nodeCache); + }) + .thenCompose(nodeReferencesAndNodes -> + selectNeighbors(storageAdapter, transaction, + nodeReferencesAndNodes, layer, + mMax, false, nodeCache, + selectedNeighbor.getNodeReferenceWithDistance().getVector())); + } + } + + private CompletableFuture>> selectNeighbors(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> nearestNeighbors, + final int layer, + final int m, + final boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { + return extendCandidatesIfNecessary(storageAdapter, readTransaction, nearestNeighbors, layer, isExtendCandidates, nodeCache, vector) + .thenApply(extendedCandidates -> { + final List selected = Lists.newArrayListWithExpectedSize(m); + final Queue candidates = + new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + candidates.addAll(extendedCandidates); + final Queue discardedCandidates = + getConfig().isKeepPrunedConnections() + ? new PriorityBlockingQueue<>(config.getM(), + Comparator.comparing(NodeReferenceWithDistance::getDistance)) + : null; + + final Metric metric = getConfig().getMetric(); + + while (!candidates.isEmpty() && selected.size() < m) { + final NodeReferenceWithDistance nearestCandidate = candidates.poll(); + boolean shouldSelect = true; + for (final NodeReferenceWithDistance alreadySelected : selected) { + if (Vector.comparativeDistance(metric, nearestCandidate.getVector(), + alreadySelected.getVector()) < nearestCandidate.getDistance()) { + shouldSelect = false; + break; + } + } + if (shouldSelect) { + selected.add(nearestCandidate); + } else if (discardedCandidates != null) { + discardedCandidates.add(nearestCandidate); + } + } + + if (discardedCandidates != null) { // isKeepPrunedConnections is set to true + while (!discardedCandidates.isEmpty() && selected.size() < m) { + selected.add(discardedCandidates.poll()); + } + } + + return ImmutableList.copyOf(selected); + }).thenCompose(selectedNeighbors -> + fetchSomeNodesIfNotCached(storageAdapter, readTransaction, layer, selectedNeighbors, nodeCache)) + .thenApply(selectedNeighbors -> { + debug(l -> + l.debug("selected neighbors={}", + selectedNeighbors.stream() + .map(selectedNeighbor -> + "(primaryKey=" + selectedNeighbor.getNodeReferenceWithDistance().getPrimaryKey() + + ",distance=" + selectedNeighbor.getNodeReferenceWithDistance().getDistance() + ")") + .collect(Collectors.joining(",")))); + return selectedNeighbors; + }); + } + + private CompletableFuture> extendCandidatesIfNecessary(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final ReadTransaction readTransaction, + @Nonnull final Iterable> candidates, + int layer, + boolean isExtendCandidates, + @Nonnull final Map> nodeCache, + @Nonnull final Vector vector) { + if (isExtendCandidates) { + final Metric metric = getConfig().getMetric(); + + final Set candidatesSeen = Sets.newConcurrentHashSet(); + for (final NodeReferenceAndNode candidate : candidates) { + candidatesSeen.add(candidate.getNode().getPrimaryKey()); + } + + final ImmutableList.Builder neighborsOfCandidatesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + for (final N neighbor : candidate.getNode().getNeighbors()) { + final Tuple neighborPrimaryKey = neighbor.getPrimaryKey(); + if (!candidatesSeen.contains(neighborPrimaryKey)) { + candidatesSeen.add(neighborPrimaryKey); + neighborsOfCandidatesBuilder.add(neighbor); + } + } + } + + final Iterable neighborsOfCandidates = neighborsOfCandidatesBuilder.build(); + + return fetchNeighborhood(storageAdapter, readTransaction, layer, neighborsOfCandidates, nodeCache) + .thenApply(withVectors -> { + final ImmutableList.Builder extendedCandidatesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + extendedCandidatesBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + for (final NodeReferenceWithVector withVector : withVectors) { + final double distance = Vector.comparativeDistance(metric, withVector.getVector(), vector); + extendedCandidatesBuilder.add(new NodeReferenceWithDistance(withVector.getPrimaryKey(), + withVector.getVector(), distance)); + } + return extendedCandidatesBuilder.build(); + }); + } else { + final ImmutableList.Builder resultBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode candidate : candidates) { + resultBuilder.add(candidate.getNodeReferenceWithDistance()); + } + + return CompletableFuture.completedFuture(resultBuilder.build()); + } + } + + private void writeLonelyNodes(@Nonnull final Transaction transaction, + @Nonnull final Tuple primaryKey, + @Nonnull final Vector vector, + final int highestLayerInclusive, + final int lowestLayerExclusive) { + for (int layer = highestLayerInclusive; layer > lowestLayerExclusive; layer --) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + writeLonelyNodeOnLayer(storageAdapter, transaction, layer, primaryKey, vector); + } + } + + private void writeLonelyNodeOnLayer(@Nonnull final StorageAdapter storageAdapter, + @Nonnull final Transaction transaction, + final int layer, + @Nonnull final Tuple primaryKey, + @Nonnull final Vector vector) { + storageAdapter.writeNode(transaction, + storageAdapter.getNodeFactory() + .create(primaryKey, vector, ImmutableList.of()), layer, + new BaseNeighborsChangeSet<>(ImmutableList.of())); + debug(l -> l.debug("written lonely node at key={} on layer={}", primaryKey, layer)); + } + + public void scanLayer(@Nonnull final Database db, + final int layer, + final int batchSize, + @Nonnull final Consumer> nodeConsumer) { + final StorageAdapter storageAdapter = getStorageAdapterForLayer(layer); + final AtomicReference lastPrimaryKeyAtomic = new AtomicReference<>(); + Tuple newPrimaryKey; + do { + final Tuple lastPrimaryKey = lastPrimaryKeyAtomic.get(); + lastPrimaryKeyAtomic.set(null); + newPrimaryKey = db.run(tr -> { + Streams.stream(storageAdapter.scanLayer(tr, layer, lastPrimaryKey, batchSize)) + .forEach(node -> { + nodeConsumer.accept(node); + lastPrimaryKeyAtomic.set(node.getPrimaryKey()); + }); + return lastPrimaryKeyAtomic.get(); + }, executor); + } while (newPrimaryKey != null); + } + + @Nonnull + private StorageAdapter getStorageAdapterForLayer(final int layer) { + return false && layer > 0 + ? new InliningStorageAdapter(getConfig(), InliningNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()) + : new CompactStorageAdapter(getConfig(), CompactNode.factory(), getSubspace(), getOnWriteListener(), getOnReadListener()); + } + + private int insertionLayer(@Nonnull final Random random) { + double lambda = 1.0 / Math.log(getConfig().getM()); + double u = 1.0 - random.nextDouble(); // Avoid log(0) + return (int) Math.floor(-Math.log(u) * lambda); + } + + @SuppressWarnings("PMD.UnusedPrivateMethod") + private void info(@Nonnull final Consumer loggerConsumer) { + if (logger.isInfoEnabled()) { + loggerConsumer.accept(logger); + } + } + + private void debug(@Nonnull final Consumer loggerConsumer) { + if (logger.isDebugEnabled()) { + loggerConsumer.accept(logger); + } + } + + private static class NodeReferenceWithLayer extends NodeReferenceWithVector { + private final int layer; + + public NodeReferenceWithLayer(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final int layer) { + super(primaryKey, vector); + this.layer = layer; + } + + public int getLayer() { + return layer; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithLayer)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return layer == ((NodeReferenceWithLayer)o).layer; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), layer); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java new file mode 100644 index 0000000000..322b4f85b0 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/HNSWHelpers.java @@ -0,0 +1,63 @@ +/* + * HNSWHelpers.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; + +/** + * Some helper methods for {@link Node}s. + */ +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class HNSWHelpers { + private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); + + private HNSWHelpers() { + // nothing + } + + /** + * Helper method to format bytes as hex strings for logging and debugging. + * @param bytes an array of bytes + * @return a {@link String} containing the hexadecimal representation of the byte array passed in + */ + @Nonnull + public static String bytesToHex(byte[] bytes) { + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return "0x" + new String(hexChars).replaceFirst("^0+(?!$)", ""); + } + + @Nonnull + public static Half halfValueOf(final double d) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(d))); + } + + @Nonnull + public static Half halfValueOf(final float f) { + return Half.shortBitsToHalf(Half.halfToShortBits(Half.valueOf(f))); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java new file mode 100644 index 0000000000..48e2398950 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningNode.java @@ -0,0 +1,94 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * TODO. + */ +class InliningNode extends AbstractNode { + @Nonnull + private static final NodeFactory FACTORY = new NodeFactory<>() { + @SuppressWarnings("unchecked") + @Nonnull + @Override + public Node create(@Nonnull final Tuple primaryKey, + @Nullable final Vector vector, + @Nonnull final List neighbors) { + return new InliningNode(primaryKey, (List)neighbors); + } + + @Nonnull + @Override + public NodeKind getNodeKind() { + return NodeKind.INLINING; + } + }; + + public InliningNode(@Nonnull final Tuple primaryKey, + @Nonnull final List neighbors) { + super(primaryKey, neighbors); + } + + @Nonnull + @Override + @SpotBugsSuppressWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") + public NodeReferenceWithVector getSelfReference(@Nullable final Vector vector) { + return new NodeReferenceWithVector(getPrimaryKey(), Objects.requireNonNull(vector)); + } + + @Nonnull + @Override + public NodeKind getKind() { + return NodeKind.INLINING; + } + + @Nonnull + @Override + public CompactNode asCompactNode() { + throw new IllegalStateException("this is not a compact node"); + } + + @Nonnull + @Override + public InliningNode asInliningNode() { + return this; + } + + @Nonnull + public static NodeFactory factory() { + return FACTORY; + } + + @Override + public String toString() { + return "I[primaryKey=" + getPrimaryKey() + + ";neighbors=" + getNeighbors() + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java new file mode 100644 index 0000000000..ebbfd4d698 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java @@ -0,0 +1,181 @@ +/* + * CompactStorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.Range; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.StreamingMode; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.AsyncIterable; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * TODO. + */ +class InliningStorageAdapter extends AbstractStorageAdapter implements StorageAdapter { + public InliningStorageAdapter(@Nonnull final HNSW.Config config, + @Nonnull final NodeFactory nodeFactory, + @Nonnull final Subspace subspace, + @Nonnull final OnWriteListener onWriteListener, + @Nonnull final OnReadListener onReadListener) { + super(config, nodeFactory, subspace, onWriteListener, onReadListener); + } + + @Nonnull + @Override + public StorageAdapter asCompactStorageAdapter() { + throw new IllegalStateException("cannot call this method on an inlining storage adapter"); + } + + @Nonnull + @Override + public StorageAdapter asInliningStorageAdapter() { + return this; + } + + @Nonnull + @Override + protected CompletableFuture> fetchNodeInternal(@Nonnull final ReadTransaction readTransaction, + final int layer, + @Nonnull final Tuple primaryKey) { + final byte[] rangeKey = getNodeKey(layer, primaryKey); + + return AsyncUtil.collect(readTransaction.getRange(Range.startsWith(rangeKey), + ReadTransaction.ROW_LIMIT_UNLIMITED, false, StreamingMode.WANT_ALL), readTransaction.getExecutor()) + .thenApply(keyValues -> nodeFromRaw(layer, primaryKey, keyValues)); + } + + @Nonnull + private Node nodeFromRaw(final int layer, final @Nonnull Tuple primaryKey, final List keyValues) { + final OnReadListener onReadListener = getOnReadListener(); + + final ImmutableList.Builder nodeReferencesWithVectorBuilder = ImmutableList.builder(); + for (final KeyValue keyValue : keyValues) { + nodeReferencesWithVectorBuilder.add(neighborFromRaw(layer, keyValue.getKey(), keyValue.getValue())); + } + + final Node node = + getNodeFactory().create(primaryKey, null, nodeReferencesWithVectorBuilder.build()); + onReadListener.onNodeRead(layer, node); + return node; + } + + @Nonnull + private NodeReferenceWithVector neighborFromRaw(final int layer, final @Nonnull byte[] key, final byte[] value) { + final OnReadListener onReadListener = getOnReadListener(); + + onReadListener.onKeyValueRead(layer, key, value); + final Tuple neighborKeyTuple = getDataSubspace().unpack(key); + final Tuple neighborValueTuple = Tuple.fromBytes(value); + + final Tuple neighborPrimaryKey = neighborKeyTuple.getNestedTuple(2); // neighbor primary key + final Vector neighborVector = StorageAdapter.vectorFromTuple(neighborValueTuple); // the entire value is the vector + return new NodeReferenceWithVector(neighborPrimaryKey, neighborVector); + } + + @Override + public void writeNodeInternal(@Nonnull final Transaction transaction, @Nonnull final Node node, + final int layer, @Nonnull final NeighborsChangeSet neighborsChangeSet) { + final InliningNode inliningNode = node.asInliningNode(); + + neighborsChangeSet.writeDelta(this, transaction, layer, inliningNode, t -> true); + getOnWriteListener().onNodeWritten(layer, node); + } + + @Nonnull + private byte[] getNodeKey(final int layer, @Nonnull final Tuple primaryKey) { + return getDataSubspace().pack(Tuple.from(layer, primaryKey)); + } + + public void writeNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Node node, @Nonnull final NodeReferenceWithVector neighbor) { + final byte[] neighborKey = getNeighborKey(layer, node, neighbor.getPrimaryKey()); + final byte[] value = StorageAdapter.tupleFromVector(neighbor.getVector()).pack(); + transaction.set(neighborKey, + value); + getOnWriteListener().onNeighborWritten(layer, node, neighbor); + getOnWriteListener().onKeyValueWritten(layer, neighborKey, value); + } + + public void deleteNeighbor(@Nonnull final Transaction transaction, final int layer, + @Nonnull final Node node, @Nonnull final Tuple neighborPrimaryKey) { + transaction.clear(getNeighborKey(layer, node, neighborPrimaryKey)); + getOnWriteListener().onNeighborDeleted(layer, node, neighborPrimaryKey); + } + + @Nonnull + private byte[] getNeighborKey(final int layer, + @Nonnull final Node node, + @Nonnull final Tuple neighborPrimaryKey) { + return getDataSubspace().pack(Tuple.from(layer, node.getPrimaryKey(), neighborPrimaryKey)); + } + + @Nonnull + @Override + public Iterable> scanLayer(@Nonnull final ReadTransaction readTransaction, int layer, + @Nullable final Tuple lastPrimaryKey, int maxNumRead) { + final byte[] layerPrefix = getDataSubspace().pack(Tuple.from(layer)); + final Range range = + lastPrimaryKey == null + ? Range.startsWith(layerPrefix) + : new Range(ByteArrayUtil.strinc(getDataSubspace().pack(Tuple.from(layer, lastPrimaryKey))), + ByteArrayUtil.strinc(layerPrefix)); + final AsyncIterable itemsIterable = + readTransaction.getRange(range, + maxNumRead, false, StreamingMode.ITERATOR); + int numRead = 0; + Tuple nodePrimaryKey = null; + ImmutableList.Builder> nodeBuilder = ImmutableList.builder(); + ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (final KeyValue item: itemsIterable) { + final NodeReferenceWithVector neighbor = + neighborFromRaw(layer, item.getKey(), item.getValue()); + final Tuple primaryKeyFromNodeReference = neighbor.getPrimaryKey(); + if (nodePrimaryKey == null) { + nodePrimaryKey = primaryKeyFromNodeReference; + } else { + if (!nodePrimaryKey.equals(primaryKeyFromNodeReference)) { + nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build())); + } + } + neighborsBuilder.add(neighbor); + numRead ++; + } + + // there may be a rest + if (numRead > 0 && numRead < maxNumRead) { + nodeBuilder.add(getNodeFactory().create(nodePrimaryKey, null, neighborsBuilder.build())); + } + + return nodeBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java new file mode 100644 index 0000000000..d68d3ae933 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InsertNeighborsChangeSet.java @@ -0,0 +1,89 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +/** + * TODO. + */ +class InsertNeighborsChangeSet implements NeighborsChangeSet { + @Nonnull + private static final Logger logger = LoggerFactory.getLogger(InsertNeighborsChangeSet.class); + + @Nonnull + private final NeighborsChangeSet parent; + + @Nonnull + private final Map insertedNeighborsMap; + + public InsertNeighborsChangeSet(@Nonnull final NeighborsChangeSet parent, + @Nonnull final List insertedNeighbors) { + this.parent = parent; + final ImmutableMap.Builder insertedNeighborsMapBuilder = ImmutableMap.builder(); + for (final N insertedNeighbor : insertedNeighbors) { + insertedNeighborsMapBuilder.put(insertedNeighbor.getPrimaryKey(), insertedNeighbor); + } + + this.insertedNeighborsMap = insertedNeighborsMapBuilder.build(); + } + + @Nonnull + @Override + public NeighborsChangeSet getParent() { + return parent; + } + + @Nonnull + @Override + public Iterable merge() { + return Iterables.concat(getParent().merge(), insertedNeighborsMap.values()); + } + + @Override + public void writeDelta(@Nonnull final InliningStorageAdapter storageAdapter, @Nonnull final Transaction transaction, + final int layer, @Nonnull final Node node, @Nonnull final Predicate tuplePredicate) { + getParent().writeDelta(storageAdapter, transaction, layer, node, + tuplePredicate.and(tuple -> !insertedNeighborsMap.containsKey(tuple))); + + for (final Map.Entry entry : insertedNeighborsMap.entrySet()) { + final Tuple primaryKey = entry.getKey(); + if (tuplePredicate.test(primaryKey)) { + storageAdapter.writeNeighbor(transaction, layer, node.asInliningNode(), + entry.getValue().asNodeReferenceWithVector()); + if (logger.isDebugEnabled()) { + logger.debug("inserted neighbor of primaryKey={} targeting primaryKey={}", node.getPrimaryKey(), + primaryKey); + } + } + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java new file mode 100644 index 0000000000..6e236a5d10 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metric.java @@ -0,0 +1,161 @@ +/* + * Metric.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; + +public interface Metric { + double distance(Double[] vector1, Double[] vector2); + + default double comparativeDistance(Double[] vector1, Double[] vector2) { + return distance(vector1, vector2); + } + + /** + * A helper method to validate that vectors can be compared. + * @param vector1 The first vector. + * @param vector2 The second vector. + */ + private static void validate(Double[] vector1, Double[] vector2) { + if (vector1 == null || vector2 == null) { + throw new IllegalArgumentException("Vectors cannot be null"); + } + if (vector1.length != vector2.length) { + throw new IllegalArgumentException( + "Vectors must have the same dimensionality. Got " + vector1.length + " and " + vector2.length + ); + } + if (vector1.length == 0) { + throw new IllegalArgumentException("Vectors cannot be empty."); + } + } + + class ManhattanMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double sumOfAbsDiffs = 0.0; + for (int i = 0; i < vector1.length; i++) { + sumOfAbsDiffs += Math.abs(vector1[i] - vector2[i]); + } + return sumOfAbsDiffs; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class EuclideanMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + return Math.sqrt(EuclideanSquareMetric.distanceInternal(vector1, vector2)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class EuclideanSquareMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + return distanceInternal(vector1, vector2); + } + + private static double distanceInternal(final Double[] vector1, final Double[] vector2) { + double sumOfSquares = 0.0d; + for (int i = 0; i < vector1.length; i++) { + double diff = vector1[i] - vector2[i]; + sumOfSquares += diff * diff; + } + return sumOfSquares; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class CosineMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + normA += vector1[i] * vector1[i]; + normB += vector2[i] * vector2[i]; + } + + // Handle the case of zero-vectors to avoid division by zero + if (normA == 0.0 || normB == 0.0) { + return Double.POSITIVE_INFINITY; + } + + return 1.0d - dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class DotProductMetric implements Metric { + @Override + public double distance(final Double[] vector1, final Double[] vector2) { + throw new UnsupportedOperationException("dot product metric is not a true metric and can only be used for ranking"); + } + + @Override + public double comparativeDistance(final Double[] vector1, final Double[] vector2) { + Metric.validate(vector1, vector2); + + double product = 0.0d; + for (int i = 0; i < vector1.length; i++) { + product += vector1[i] * vector2[i]; + } + return -product; + } + + @Override + @Nonnull + public String toString() { + return this.getClass().getSimpleName(); + } + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java new file mode 100644 index 0000000000..8c30faf852 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Metrics.java @@ -0,0 +1,43 @@ +/* + * Metric.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; + +public enum Metrics { + MANHATTAN_METRIC(new Metric.ManhattanMetric()), + EUCLIDEAN_METRIC(new Metric.EuclideanMetric()), + EUCLIDEAN_SQUARE_METRIC(new Metric.EuclideanSquareMetric()), + COSINE_METRIC(new Metric.CosineMetric()), + DOT_PRODUCT_METRIC(new Metric.DotProductMetric()); + + @Nonnull + private final Metric metric; + + Metrics(@Nonnull final Metric metric) { + this.metric = metric; + } + + @Nonnull + public Metric getMetric() { + return metric; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java new file mode 100644 index 0000000000..b7f38ef1a7 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NeighborsChangeSet.java @@ -0,0 +1,42 @@ +/* + * InliningNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.function.Predicate; + +/** + * TODO. + */ +interface NeighborsChangeSet { + @Nullable + NeighborsChangeSet getParent(); + + @Nonnull + Iterable merge(); + + void writeDelta(@Nonnull InliningStorageAdapter storageAdapter, @Nonnull Transaction transaction, int layer, + @Nonnull Node node, @Nonnull Predicate primaryKeyPredicate); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java new file mode 100644 index 0000000000..f2c623f882 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Node.java @@ -0,0 +1,59 @@ +/* + * Node.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +/** + * TODO. + * @param neighbor type + */ +public interface Node { + @Nonnull + Tuple getPrimaryKey(); + + @Nonnull + N getSelfReference(@Nullable Vector vector); + + @Nonnull + List getNeighbors(); + + @Nonnull + N getNeighbor(int index); + + /** + * Return the kind of the node, i.e. {@link NodeKind#COMPACT} or {@link NodeKind#INLINING}. + * @return the kind of this node as a {@link NodeKind} + */ + @Nonnull + NodeKind getKind(); + + @Nonnull + CompactNode asCompactNode(); + + @Nonnull + InliningNode asInliningNode(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java new file mode 100644 index 0000000000..321e3f53d8 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeFactory.java @@ -0,0 +1,37 @@ +/* + * NodeFactory.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; + +public interface NodeFactory { + @Nonnull + Node create(@Nonnull Tuple primaryKey, @Nullable Vector vector, + @Nonnull List neighbors); + + @Nonnull + NodeKind getNodeKind(); +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java new file mode 100644 index 0000000000..13d71a1b9b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeKind.java @@ -0,0 +1,60 @@ +/* + * NodeKind.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; + +/** + * Enum to capture the kind of node. + */ +public enum NodeKind { + COMPACT((byte)0x00), + INLINING((byte)0x01); + + private final byte serialized; + + NodeKind(final byte serialized) { + this.serialized = serialized; + } + + public byte getSerialized() { + return serialized; + } + + @Nonnull + static NodeKind fromSerializedNodeKind(byte serializedNodeKind) { + final NodeKind nodeKind; + switch (serializedNodeKind) { + case 0x00: + nodeKind = NodeKind.COMPACT; + break; + case 0x01: + nodeKind = NodeKind.INLINING; + break; + default: + throw new IllegalArgumentException("unknown node kind"); + } + Verify.verify(nodeKind.getSerialized() == serializedNodeKind); + return nodeKind; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java new file mode 100644 index 0000000000..59b831d04d --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReference.java @@ -0,0 +1,72 @@ +/* + * NodeReference.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.google.common.collect.Streams; + +import javax.annotation.Nonnull; +import java.util.Objects; + +public class NodeReference { + @Nonnull + private final Tuple primaryKey; + + public NodeReference(@Nonnull final Tuple primaryKey) { + this.primaryKey = primaryKey; + } + + @Nonnull + public Tuple getPrimaryKey() { + return primaryKey; + } + + @Nonnull + public NodeReferenceWithVector asNodeReferenceWithVector() { + throw new IllegalStateException("method should not be called"); + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReference)) { + return false; + } + final NodeReference that = (NodeReference)o; + return Objects.equals(primaryKey, that.primaryKey); + } + + @Override + public int hashCode() { + return Objects.hashCode(primaryKey); + } + + @Override + public String toString() { + return "NR[primaryKey=" + primaryKey + "]"; + } + + @Nonnull + public static Iterable primaryKeys(@Nonnull Iterable neighbors) { + return () -> Streams.stream(neighbors) + .map(NodeReference::getPrimaryKey) + .iterator(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java new file mode 100644 index 0000000000..bbf74e864a --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceAndNode.java @@ -0,0 +1,57 @@ +/* + * NodeReferenceAndNode.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nonnull; +import java.util.List; + +public class NodeReferenceAndNode { + @Nonnull + private final NodeReferenceWithDistance nodeReferenceWithDistance; + @Nonnull + private final Node node; + + public NodeReferenceAndNode(@Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance, @Nonnull final Node node) { + this.nodeReferenceWithDistance = nodeReferenceWithDistance; + this.node = node; + } + + @Nonnull + public NodeReferenceWithDistance getNodeReferenceWithDistance() { + return nodeReferenceWithDistance; + } + + @Nonnull + public Node getNode() { + return node; + } + + @Nonnull + public static List getReferences(@Nonnull List> referencesAndNodes) { + final ImmutableList.Builder referencesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode referenceWithNode : referencesAndNodes) { + referencesBuilder.add(referenceWithNode.getNodeReferenceWithDistance()); + } + return referencesBuilder.build(); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java new file mode 100644 index 0000000000..bc9470735c --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithDistance.java @@ -0,0 +1,58 @@ +/* + * NodeReferenceWithDistance.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; + +import javax.annotation.Nonnull; +import java.util.Objects; + +public class NodeReferenceWithDistance extends NodeReferenceWithVector { + private final double distance; + + public NodeReferenceWithDistance(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector, + final double distance) { + super(primaryKey, vector); + this.distance = distance; + } + + public double getDistance() { + return distance; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithDistance)) { + return false; + } + if (!super.equals(o)) { + return false; + } + final NodeReferenceWithDistance that = (NodeReferenceWithDistance)o; + return Double.compare(distance, that.distance) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), distance); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java new file mode 100644 index 0000000000..e21b221622 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/NodeReferenceWithVector.java @@ -0,0 +1,76 @@ +/* + * NodeReferenceWithVector.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Objects; + +import javax.annotation.Nonnull; + +public class NodeReferenceWithVector extends NodeReference { + @Nonnull + private final Vector vector; + + public NodeReferenceWithVector(@Nonnull final Tuple primaryKey, @Nonnull final Vector vector) { + super(primaryKey); + this.vector = vector; + } + + @Nonnull + public Vector getVector() { + return vector; + } + + @Nonnull + public Vector getDoubleVector() { + return vector.toDoubleVector(); + } + + @Nonnull + @Override + public NodeReferenceWithVector asNodeReferenceWithVector() { + return this; + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof NodeReferenceWithVector)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equal(vector, ((NodeReferenceWithVector)o).vector); + } + + @Override + public int hashCode() { + return Objects.hashCode(super.hashCode(), vector); + } + + @Override + public String toString() { + return "NRV[primaryKey=" + getPrimaryKey() + + ";vector=" + vector.toString(3) + + "]"; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java new file mode 100644 index 0000000000..753648cf77 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnReadListener.java @@ -0,0 +1,46 @@ +/* + * OnReadListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import javax.annotation.Nonnull; +import java.util.concurrent.CompletableFuture; + +/** + * Function interface for a call back whenever we read the slots for a node. + */ +public interface OnReadListener { + OnReadListener NOOP = new OnReadListener() { + }; + + default CompletableFuture> onAsyncRead(@Nonnull CompletableFuture> future) { + return future; + } + + default void onNodeRead(int layer, @Nonnull Node node) { + // nothing + } + + default void onKeyValueRead(int layer, + @Nonnull byte[] key, + @Nonnull byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java new file mode 100644 index 0000000000..fd4a096208 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/OnWriteListener.java @@ -0,0 +1,49 @@ +/* + * OnWriteListener.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; + +/** + * Function interface for a call back whenever we read the slots for a node. + */ +public interface OnWriteListener { + OnWriteListener NOOP = new OnWriteListener() { + }; + + default void onNodeWritten(final int layer, @Nonnull final Node node) { + // nothing + } + + default void onNeighborWritten(final int layer, @Nonnull final Node node, final NodeReference neighbor) { + // nothing + } + + default void onNeighborDeleted(final int layer, @Nonnull final Node node, @Nonnull Tuple neighborPrimaryKey) { + // nothing + } + + default void onKeyValueWritten(final int layer, @Nonnull byte[] key, @Nonnull byte[] value) { + // nothing + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java new file mode 100644 index 0000000000..82bd281c62 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/StorageAdapter.java @@ -0,0 +1,184 @@ +/* + * StorageAdapter.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Verify; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.concurrent.CompletableFuture; + +/** + * Storage adapter used for serialization and deserialization of nodes. + */ +interface StorageAdapter { + byte SUBSPACE_PREFIX_ENTRY_NODE = 0x01; + byte SUBSPACE_PREFIX_DATA = 0x02; + + /** + * Get the {@link HNSW.Config} associated with this storage adapter. + * @return the configuration used by this storage adapter + */ + @Nonnull + HNSW.Config getConfig(); + + @Nonnull + NodeFactory getNodeFactory(); + + @Nonnull + NodeKind getNodeKind(); + + @Nonnull + StorageAdapter asCompactStorageAdapter(); + + @Nonnull + StorageAdapter asInliningStorageAdapter(); + + /** + * Get the subspace used to store this r-tree. + * + * @return r-tree subspace + */ + @Nonnull + Subspace getSubspace(); + + @Nonnull + Subspace getDataSubspace(); + + /** + * Get the on-write listener. + * + * @return the on-write listener. + */ + @Nonnull + OnWriteListener getOnWriteListener(); + + /** + * Get the on-read listener. + * + * @return the on-read listener. + */ + @Nonnull + OnReadListener getOnReadListener(); + + @Nonnull + CompletableFuture> fetchNode(@Nonnull ReadTransaction readTransaction, + int layer, + @Nonnull Tuple primaryKey); + + void writeNode(@Nonnull Transaction transaction, @Nonnull Node node, int layer, + @Nonnull NeighborsChangeSet changeSet); + + Iterable> scanLayer(@Nonnull ReadTransaction readTransaction, int layer, @Nullable Tuple lastPrimaryKey, + int maxNumRead); + + @Nonnull + static CompletableFuture fetchEntryNodeReference(@Nonnull final ReadTransaction readTransaction, + @Nonnull final Subspace subspace, + @Nonnull final OnReadListener onReadListener) { + final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE)); + final byte[] key = entryNodeSubspace.pack(); + + return readTransaction.get(key) + .thenApply(valueBytes -> { + if (valueBytes == null) { + return null; // not a single node in the index + } + onReadListener.onKeyValueRead(-1, key, valueBytes); + + final Tuple entryTuple = Tuple.fromBytes(valueBytes); + final int lMax = (int)entryTuple.getLong(0); + final Tuple primaryKey = entryTuple.getNestedTuple(1); + final Tuple vectorTuple = entryTuple.getNestedTuple(2); + return new EntryNodeReference(primaryKey, StorageAdapter.vectorFromTuple(vectorTuple), lMax); + }); + } + + static void writeEntryNodeReference(@Nonnull final Transaction transaction, + @Nonnull final Subspace subspace, + @Nonnull final EntryNodeReference entryNodeReference, + @Nonnull final OnWriteListener onWriteListener) { + final Subspace entryNodeSubspace = subspace.subspace(Tuple.from(SUBSPACE_PREFIX_ENTRY_NODE)); + final byte[] key = entryNodeSubspace.pack(); + final byte[] value = Tuple.from(entryNodeReference.getLayer(), + entryNodeReference.getPrimaryKey(), + StorageAdapter.tupleFromVector(entryNodeReference.getVector())).pack(); + transaction.set(key, + value); + onWriteListener.onKeyValueWritten(entryNodeReference.getLayer(), key, value); + } + + @Nonnull + static Vector.HalfVector vectorFromTuple(final Tuple vectorTuple) { + return vectorFromBytes(vectorTuple.getBytes(0)); + } + + @Nonnull + static Vector.HalfVector vectorFromBytes(final byte[] vectorBytes) { + final int bytesLength = vectorBytes.length; + Verify.verify(bytesLength % 2 == 0); + final int componentSize = bytesLength >>> 1; + final Half[] vectorHalfs = new Half[componentSize]; + for (int i = 0; i < componentSize; i ++) { + vectorHalfs[i] = Half.shortBitsToHalf(shortFromBytes(vectorBytes, i << 1)); + } + return new Vector.HalfVector(vectorHalfs); + } + + + @Nonnull + @SuppressWarnings("PrimitiveArrayArgumentToVarargsMethod") + static Tuple tupleFromVector(final Vector vector) { + return Tuple.from(bytesFromVector(vector)); + } + + @Nonnull + static byte[] bytesFromVector(final Vector vector) { + final byte[] vectorBytes = new byte[2 * vector.size()]; + for (int i = 0; i < vector.size(); i ++) { + final byte[] componentBytes = bytesFromShort(Half.halfToShortBits(vector.getComponent(i))); + final int indexTimesTwo = i << 1; + vectorBytes[indexTimesTwo] = componentBytes[0]; + vectorBytes[indexTimesTwo + 1] = componentBytes[1]; + } + return vectorBytes; + } + + static short shortFromBytes(final byte[] bytes, final int offset) { + Verify.verify(offset % 2 == 0); + int high = bytes[offset] & 0xFF; // Convert to unsigned int + int low = bytes[offset + 1] & 0xFF; + + return (short) ((high << 8) | low); + } + + static byte[] bytesFromShort(final short value) { + byte[] result = new byte[2]; + result[0] = (byte) ((value >> 8) & 0xFF); // high byte first + result[1] = (byte) (value & 0xFF); // low byte second + return result; + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java new file mode 100644 index 0000000000..bfa179ea2b --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/Vector.java @@ -0,0 +1,204 @@ +/* + * HNSWHelpers.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.christianheina.langx.half4j.Half; +import com.google.common.base.Suppliers; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.Objects; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * TODO. + * @param representation type + */ +public abstract class Vector { + @Nonnull + protected R[] data; + @Nonnull + protected Supplier hashCodeSupplier; + + public Vector(@Nonnull final R[] data) { + this.data = data; + this.hashCodeSupplier = Suppliers.memoize(this::computeHashCode); + } + + public int size() { + return data.length; + } + + @Nonnull + R getComponent(int dimension) { + return data[dimension]; + } + + @Nonnull + public R[] getData() { + return data; + } + + @Nonnull + public abstract byte[] getRawData(); + + @Nonnull + public abstract Vector toHalfVector(); + + @Nonnull + public abstract DoubleVector toDoubleVector(); + + @Override + public boolean equals(final Object o) { + if (!(o instanceof Vector)) { + return false; + } + final Vector vector = (Vector)o; + return Objects.deepEquals(data, vector.data); + } + + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } + + private int computeHashCode() { + return Arrays.hashCode(data); + } + + @Override + public String toString() { + return toString(3); + } + + public String toString(final int limitDimensions) { + if (limitDimensions < data.length) { + return "[" + Arrays.stream(Arrays.copyOfRange(data, 0, limitDimensions)) + .map(String::valueOf) + .collect(Collectors.joining(",")) + ", ...]"; + } else { + return "[" + Arrays.stream(data) + .map(String::valueOf) + .collect(Collectors.joining(",")) + "]"; + } + } + + public static class HalfVector extends Vector { + @Nonnull + private final Supplier toDoubleVectorSupplier; + @Nonnull + private final Supplier toRawDataSupplier; + + public HalfVector(@Nonnull final Half[] data) { + super(data); + this.toDoubleVectorSupplier = Suppliers.memoize(this::computeDoubleVector); + this.toRawDataSupplier = Suppliers.memoize(this::computeRawData); + } + + @Nonnull + @Override + public Vector toHalfVector() { + return this; + } + + @Nonnull + @Override + public DoubleVector toDoubleVector() { + return toDoubleVectorSupplier.get(); + } + + @Nonnull + public DoubleVector computeDoubleVector() { + Double[] result = new Double[data.length]; + for (int i = 0; i < data.length; i ++) { + result[i] = data[i].doubleValue(); + } + return new DoubleVector(result); + } + + @Nonnull + @Override + public byte[] getRawData() { + return toRawDataSupplier.get(); + } + + @Nonnull + private byte[] computeRawData() { + return StorageAdapter.bytesFromVector(this); + } + + @Nonnull + public static HalfVector halfVectorFromBytes(@Nonnull final byte[] vectorBytes) { + return StorageAdapter.vectorFromBytes(vectorBytes); + } + } + + public static class DoubleVector extends Vector { + @Nonnull + private final Supplier toHalfVectorSupplier; + + public DoubleVector(@Nonnull final Double[] data) { + super(data); + this.toHalfVectorSupplier = Suppliers.memoize(this::computeHalfVector); + } + + @Nonnull + @Override + public HalfVector toHalfVector() { + return toHalfVectorSupplier.get(); + } + + @Nonnull + public HalfVector computeHalfVector() { + Half[] result = new Half[data.length]; + for (int i = 0; i < data.length; i ++) { + result[i] = Half.valueOf(data[i]); + } + return new HalfVector(result); + } + + @Nonnull + @Override + public DoubleVector toDoubleVector() { + return this; + } + + @Nonnull + @Override + public byte[] getRawData() { + // TODO + throw new UnsupportedOperationException("not implemented yet"); + } + } + + static double distance(@Nonnull Metric metric, + @Nonnull final Vector vector1, + @Nonnull final Vector vector2) { + return metric.distance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); + } + + static double comparativeDistance(@Nonnull Metric metric, + @Nonnull final Vector vector1, + @Nonnull final Vector vector2) { + return metric.comparativeDistance(vector1.toDoubleVector().getData(), vector2.toDoubleVector().getData()); + } +} diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java new file mode 100644 index 0000000000..5565b7f9f6 --- /dev/null +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/package-info.java @@ -0,0 +1,24 @@ +/* + * package-info.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Classes and interfaces related to the Hilbert R-tree implementation. + */ +package com.apple.foundationdb.async.hnsw; diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java index db4e4cf636..a11ac8b462 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/NodeHelpers.java @@ -1,5 +1,5 @@ /* - * NodeHelpers.java + * HNSWHelpers.java * * This source file is part of the FoundationDB open source project * diff --git a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java index f60c17da63..2623cff1dc 100644 --- a/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java +++ b/fdb-extensions/src/main/java/com/apple/foundationdb/async/rtree/StorageAdapter.java @@ -36,7 +36,6 @@ * Storage adapter used for serialization and deserialization of nodes. */ interface StorageAdapter { - /** * Get the {@link RTree.Config} associated with this storage adapter. * @return the configuration used by this storage adapter diff --git a/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java new file mode 100644 index 0000000000..a0238fd4fe --- /dev/null +++ b/fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWModificationTest.java @@ -0,0 +1,667 @@ +/* + * HNSWModificationTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.async.hnsw; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.async.hnsw.Vector.HalfVector; +import com.apple.foundationdb.async.rtree.RTree; +import com.apple.foundationdb.test.TestDatabaseExtension; +import com.apple.foundationdb.test.TestExecutors; +import com.apple.foundationdb.test.TestSubspaceExtension; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.test.Tags; +import com.christianheina.langx.half4j.Half; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.NavigableSet; +import java.util.Objects; +import java.util.Random; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * Tests testing insert/update/deletes of data into/in/from {@link RTree}s. + */ +@Execution(ExecutionMode.CONCURRENT) +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +@Tag(Tags.RequiresFDB) +@Tag(Tags.Slow) +@Disabled +public class HNSWModificationTest { + private static final Logger logger = LoggerFactory.getLogger(HNSWModificationTest.class); + private static final int NUM_TEST_RUNS = 5; + private static final int NUM_SAMPLES = 10_000; + + @RegisterExtension + static final TestDatabaseExtension dbExtension = new TestDatabaseExtension(); + @RegisterExtension + TestSubspaceExtension rtSubspace = new TestSubspaceExtension(dbExtension); + @RegisterExtension + TestSubspaceExtension rtSecondarySubspace = new TestSubspaceExtension(dbExtension); + + private Database db; + + @BeforeEach + public void setUpDb() { + db = dbExtension.getDatabase(); + } + + @Test + public void testCompactSerialization() { + final Random random = new Random(0); + final CompactStorageAdapter storageAdapter = + new CompactStorageAdapter(HNSW.DEFAULT_CONFIG, CompactNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomCompactNode = + createRandomCompactNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomCompactNode, 0); + return randomCompactNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> { + Assertions.assertAll( + () -> Assertions.assertInstanceOf(CompactNode.class, node), + () -> Assertions.assertEquals(NodeKind.COMPACT, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> Assertions.assertEquals(node.asCompactNode().getVector(), + originalNode.asCompactNode().getVector()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertEquals(neighbors, originalNeighbors); + } + ); + }).join()); + } + + @Test + public void testInliningSerialization() { + final Random random = new Random(0); + final InliningStorageAdapter storageAdapter = + new InliningStorageAdapter(HNSW.DEFAULT_CONFIG, InliningNode.factory(), rtSubspace.getSubspace(), + OnWriteListener.NOOP, OnReadListener.NOOP); + final Node originalNode = + db.run(tr -> { + final NodeFactory nodeFactory = storageAdapter.getNodeFactory(); + + final Node randomInliningNode = + createRandomInliningNode(random, nodeFactory, 768, 16); + + writeNode(tr, storageAdapter, randomInliningNode, 0); + return randomInliningNode; + }); + + db.run(tr -> storageAdapter.fetchNode(tr, 0, originalNode.getPrimaryKey()) + .thenAccept(node -> Assertions.assertAll( + () -> Assertions.assertInstanceOf(InliningNode.class, node), + () -> Assertions.assertEquals(NodeKind.INLINING, node.getKind()), + () -> Assertions.assertEquals(node.getPrimaryKey(), originalNode.getPrimaryKey()), + () -> { + final ArrayList neighbors = + Lists.newArrayList(node.getNeighbors()); + neighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); // should not be necessary the way it is stored + final ArrayList originalNeighbors = + Lists.newArrayList(originalNode.getNeighbors()); + originalNeighbors.sort(Comparator.comparing(NodeReference::getPrimaryKey)); + Assertions.assertEquals(neighbors, originalNeighbors); + } + )).join()); + } + + @Test + public void testBasicInsert() { + final Random random = new Random(0); + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final int dimensions = 128; + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) + .setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + for (int i = 0; i < 1000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> new NodeReferenceWithVector(createNextPrimaryKey(nextNodeIdAtomic), createRandomVector(random, dimensions))); + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> result = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, 10, 100, createRandomVector(random, dimensions)).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : result) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + private int basicInsertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference != null) { + hnsw.insert(tr, newNodeReference).join(); + } + } + final long endTs = System.nanoTime(); + logger.info("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + private int insertBatch(final HNSW hnsw, final int batchSize, + @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, + @Nonnull final Function insertFunction) { + return db.run(tr -> { + onReadListener.reset(); + final long nextNodeId = nextNodeIdAtomic.get(); + final long beginTs = System.nanoTime(); + final ImmutableList.Builder nodeReferenceWithVectorBuilder = + ImmutableList.builder(); + for (int i = 0; i < batchSize; i ++) { + final var newNodeReference = insertFunction.apply(tr); + if (newNodeReference != null) { + nodeReferenceWithVectorBuilder.add(newNodeReference); + } + } + hnsw.insertBatch(tr, nodeReferenceWithVectorBuilder.build()).join(); + final long endTs = System.nanoTime(); + logger.info("inserted batch batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}", batchSize, nextNodeId, + TimeUnit.NANOSECONDS.toMillis(endTs - beginTs), onReadListener.getNodeCountByLayer(), onReadListener.getSumMByLayer()); + return batchSize; + }); + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void testSIFTInsert10k() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 10; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + + final AtomicReference queryVectorAtomic = new AtomicReference<>(); + final NavigableSet trueResults = new ConcurrentSkipListSet<>( + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 10000;) { + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = new HalfVector(halfs); + final HalfVector queryVector = queryVectorAtomic.get(); + if (queryVector == null) { + queryVectorAtomic.set(currentVector); + return null; + } else { + final double currentDistance = + Vector.comparativeDistance(metric, currentVector, queryVector); + if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { + trueResults.add( + new NodeReferenceWithDistance(currentPrimaryKey, currentVector, + Vector.comparativeDistance(metric, currentVector, queryVector))); + } + if (trueResults.size() > k) { + trueResults.remove(trueResults.last()); + } + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + } + }); + } + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { + logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void testSIFTInsert10kWithBatchInsert() throws Exception { + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + final int k = 10; + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(metric).setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + + final AtomicReference queryVectorAtomic = new AtomicReference<>(); + final NavigableSet trueResults = new ConcurrentSkipListSet<>( + Comparator.comparing(NodeReferenceWithDistance::getDistance)); + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 10000;) { + i += insertBatch(hnsw, 100, nextNodeIdAtomic, onReadListener, + tr -> { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final Tuple currentPrimaryKey = createNextPrimaryKey(nextNodeIdAtomic); + final HalfVector currentVector = new HalfVector(halfs); + final HalfVector queryVector = queryVectorAtomic.get(); + if (queryVector == null) { + queryVectorAtomic.set(currentVector); + return null; + } else { + final double currentDistance = + Vector.comparativeDistance(metric, currentVector, queryVector); + if (trueResults.size() < k || trueResults.last().getDistance() > currentDistance) { + trueResults.add( + new NodeReferenceWithDistance(currentPrimaryKey, currentVector, + Vector.comparativeDistance(metric, currentVector, queryVector))); + } + if (trueResults.size() > k) { + trueResults.remove(trueResults.last()); + } + return new NodeReferenceWithVector(currentPrimaryKey, currentVector); + } + }); + } + } + + onReadListener.reset(); + final long beginTs = System.nanoTime(); + final List> results = + db.run(tr -> hnsw.kNearestNeighborsSearch(tr, k, 100, queryVectorAtomic.get()).join()); + final long endTs = System.nanoTime(); + + for (NodeReferenceAndNode nodeReferenceAndNode : results) { + final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode.getNodeReferenceWithDistance(); + logger.info("retrieved result nodeId = {} at distance= {}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults) { + logger.info("true result nodeId ={} at distance={}", nodeReferenceWithDistance.getPrimaryKey().getLong(0), + nodeReferenceWithDistance.getDistance()); + } + + System.out.println(onReadListener.getNodeCountByLayer()); + System.out.println(onReadListener.getBytesReadByLayer()); + + logger.info("search transaction took elapsedTime={}ms", TimeUnit.NANOSECONDS.toMillis(endTs - beginTs)); + } + + @Test + public void testBasicInsertAndScanLayer() throws Exception { + final Random random = new Random(0); + final AtomicLong nextNodeId = new AtomicLong(0L); + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setM(4).setMMax(4).setMMax0(4).build(), + OnWriteListener.NOOP, OnReadListener.NOOP); + + db.run(tr -> { + for (int i = 0; i < 100; i ++) { + hnsw.insert(tr, createNextPrimaryKey(nextNodeId), createRandomVector(random, 2)).join(); + } + return null; + }); + + int layer = 0; + while (true) { + if (!dumpLayer(hnsw, layer++)) { + break; + } + } + } + + @Test + public void testManyRandomVectors() { + final Random random = new Random(); + for (long l = 0L; l < 3000000; l ++) { + final HalfVector randomVector = createRandomVector(random, 768); + final Tuple vectorTuple = StorageAdapter.tupleFromVector(randomVector); + final Vector roundTripVector = StorageAdapter.vectorFromTuple(vectorTuple); + Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), randomVector, roundTripVector); + Assertions.assertEquals(randomVector, roundTripVector); + } + } + + @Test + @Timeout(value = 150, unit = TimeUnit.MINUTES) + public void testSIFTVectors() throws Exception { + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); + + final TestOnReadListener onReadListener = new TestOnReadListener(); + + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), + HNSW.DEFAULT_CONFIG.toBuilder().setMetric(Metrics.EUCLIDEAN_METRIC.getMetric()) + .setM(32).setMMax(32).setMMax0(64).build(), + OnWriteListener.NOOP, onReadListener); + + + final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv"; + final int dimensions = 128; + final var referenceVector = createRandomVector(new Random(0), dimensions); + long count = 0L; + double mean = 0.0d; + double mean2 = 0.0d; + + try (BufferedReader br = new BufferedReader(new FileReader(tsvFile))) { + for (int i = 0; i < 100_000; i ++) { + final String line; + try { + line = br.readLine(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + final String[] values = Objects.requireNonNull(line).split("\t"); + Assertions.assertEquals(dimensions, values.length); + final Half[] halfs = new Half[dimensions]; + for (int c = 0; c < values.length; c++) { + final String value = values[c]; + halfs[c] = HNSWHelpers.halfValueOf(Double.parseDouble(value)); + } + final HalfVector newVector = new HalfVector(halfs); + final double distance = Vector.comparativeDistance(Metrics.EUCLIDEAN_METRIC.getMetric(), + referenceVector, newVector); + count++; + final double delta = distance - mean; + mean += delta / count; + final double delta2 = distance - mean; + mean2 += delta * delta2; + } + } + final double sampleVariance = mean2 / (count - 1); + final double standardDeviation = Math.sqrt(sampleVariance); + logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, + standardDeviation / mean); + } + + + @ParameterizedTest + @ValueSource(ints = {2, 3, 10, 100, 768}) + public void testManyVectorsStandardDeviation(final int dimensionality) { + final Random random = new Random(); + final Metric metric = Metrics.EUCLIDEAN_METRIC.getMetric(); + long count = 0L; + double mean = 0.0d; + double mean2 = 0.0d; + for (long i = 0L; i < 100000; i ++) { + final HalfVector vector1 = createRandomVector(random, dimensionality); + final HalfVector vector2 = createRandomVector(random, dimensionality); + final double distance = Vector.comparativeDistance(metric, vector1, vector2); + count = i + 1; + final double delta = distance - mean; + mean += delta / count; + final double delta2 = distance - mean; + mean2 += delta * delta2; + } + final double sampleVariance = mean2 / (count - 1); + final double standardDeviation = Math.sqrt(sampleVariance); + logger.info("mean={}, sample_variance={}, stddeviation={}, cv={}", mean, sampleVariance, standardDeviation, + standardDeviation / mean); + } + + private boolean dumpLayer(final HNSW hnsw, final int layer) throws IOException { + final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv"; + final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv"; + + final AtomicLong numReadAtomic = new AtomicLong(0L); + try (final BufferedWriter verticesWriter = new BufferedWriter(new FileWriter(verticesFileName)); + final BufferedWriter edgesWriter = new BufferedWriter(new FileWriter(edgesFileName))) { + hnsw.scanLayer(db, layer, 100, node -> { + final CompactNode compactNode = node.asCompactNode(); + final Vector vector = compactNode.getVector(); + try { + verticesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + + vector.getComponent(0) + "," + + vector.getComponent(1)); + verticesWriter.newLine(); + + for (final var neighbor : compactNode.getNeighbors()) { + edgesWriter.write(compactNode.getPrimaryKey().getLong(0) + "," + + neighbor.getPrimaryKey().getLong(0)); + edgesWriter.newLine(); + } + numReadAtomic.getAndIncrement(); + } catch (final IOException e) { + throw new RuntimeException("unable to write to file", e); + } + }); + } + return numReadAtomic.get() != 0; + } + + private void writeNode(@Nonnull final Transaction transaction, + @Nonnull final StorageAdapter storageAdapter, + @Nonnull final Node node, + final int layer) { + final NeighborsChangeSet insertChangeSet = + new InsertNeighborsChangeSet<>(new BaseNeighborsChangeSet<>(ImmutableList.of()), + node.getNeighbors()); + storageAdapter.writeNode(transaction, node, layer, insertChangeSet); + } + + @Nonnull + private Node createRandomCompactNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReference(random)); + } + + return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private Node createRandomInliningNode(@Nonnull final Random random, + @Nonnull final NodeFactory nodeFactory, + final int dimensionality, + final int numberOfNeighbors) { + final Tuple primaryKey = createRandomPrimaryKey(random); + final ImmutableList.Builder neighborsBuilder = ImmutableList.builder(); + for (int i = 0; i < numberOfNeighbors; i ++) { + neighborsBuilder.add(createRandomNodeReferenceWithVector(random, dimensionality)); + } + + return nodeFactory.create(primaryKey, createRandomVector(random, dimensionality), neighborsBuilder.build()); + } + + @Nonnull + private NodeReference createRandomNodeReference(@Nonnull final Random random) { + return new NodeReference(createRandomPrimaryKey(random)); + } + + @Nonnull + private NodeReferenceWithVector createRandomNodeReferenceWithVector(@Nonnull final Random random, final int dimensionality) { + return new NodeReferenceWithVector(createRandomPrimaryKey(random), createRandomVector(random, dimensionality)); + } + + @Nonnull + private static Tuple createRandomPrimaryKey(final @Nonnull Random random) { + return Tuple.from(random.nextLong()); + } + + @Nonnull + private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic) { + return Tuple.from(nextIdAtomic.getAndIncrement()); + } + + @Nonnull + private HalfVector createRandomVector(@Nonnull final Random random, final int dimensionality) { + final Half[] components = new Half[dimensionality]; + for (int d = 0; d < dimensionality; d ++) { + // don't ask + components[d] = HNSWHelpers.halfValueOf(random.nextDouble()); + } + return new HalfVector(components); + } + + private static class TestOnReadListener implements OnReadListener { + final Map nodeCountByLayer; + final Map sumMByLayer; + final Map bytesReadByLayer; + + public TestOnReadListener() { + this.nodeCountByLayer = Maps.newConcurrentMap(); + this.sumMByLayer = Maps.newConcurrentMap(); + this.bytesReadByLayer = Maps.newConcurrentMap(); + } + + public Map getNodeCountByLayer() { + return nodeCountByLayer; + } + + public Map getBytesReadByLayer() { + return bytesReadByLayer; + } + + public Map getSumMByLayer() { + return sumMByLayer; + } + + public void reset() { + nodeCountByLayer.clear(); + bytesReadByLayer.clear(); + sumMByLayer.clear(); + } + + @Override + public void onNodeRead(final int layer, @Nonnull final Node node) { + nodeCountByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + 1L); + sumMByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + node.getNeighbors().size()); + } + + @Override + public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + bytesReadByLayer.compute(layer, (l, oldValue) -> (oldValue == null ? 0 : oldValue) + + key.length + value.length); + } + } +} diff --git a/fdb-record-layer-core/fdb-record-layer-core.gradle b/fdb-record-layer-core/fdb-record-layer-core.gradle index 41fcfd996a..e6a7416e47 100644 --- a/fdb-record-layer-core/fdb-record-layer-core.gradle +++ b/fdb-record-layer-core/fdb-record-layer-core.gradle @@ -31,6 +31,7 @@ dependencies { api(libs.protobuf) implementation(libs.slf4j.api) implementation(libs.guava) + implementation(libs.half4j) compileOnly(libs.jsr305) compileOnly(libs.autoService) annotationProcessor(libs.autoService) diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java index 2b66805b2f..21dc8f23b4 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexOptions.java @@ -223,6 +223,15 @@ public class IndexOptions { */ public static final String RTREE_USE_NODE_SLOT_INDEX = "rtreeUseNodeSlotIndex"; + public static final String HNSW_METRIC = "hnswMetric"; + public static final String HNSW_M = "hnswM"; + public static final String HNSW_M_MAX = "hnswMax"; + public static final String HNSW_M_MAX_0 = "hnswMax0"; + public static final String HNSW_EF_SEARCH = "hnswEfSearch"; + public static final String HNSW_EF_CONSTRUCTION = "hnswEfConstruction"; + public static final String HNSW_EXTEND_CANDIDATES = "hnswExtendCandidates"; + public static final String HNSW_KEEP_PRUNED_CONNECTIONS = "hnswKeepPrunedConnections"; + private IndexOptions() { } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java index 1d19171093..8d10f26d9e 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/metadata/IndexTypes.java @@ -164,6 +164,11 @@ public class IndexTypes { */ public static final String MULTIDIMENSIONAL = "multidimensional"; + /** + * An index using an HNSW structure. + */ + public static final String VECTOR = "vector"; + private IndexTypes() { } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBRecordStore.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBRecordStore.java index 23dd1f17f6..4537529749 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBRecordStore.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBRecordStore.java @@ -644,11 +644,18 @@ private FDBStoredRecord dryRunSetSizeInfo(@Nonnull Record return recordBuilder.build(); } + @SuppressWarnings("unchecked") @Nonnull private FDBStoredRecord serializeAndSaveRecord(@Nonnull RecordSerializer typedSerializer, @Nonnull final FDBStoredRecordBuilder recordBuilder, @Nonnull final RecordMetaData metaData, @Nullable FDBStoredSizes oldSizeInfo) { final Tuple primaryKey = recordBuilder.getPrimaryKey(); final FDBRecordVersion version = recordBuilder.getVersion(); + + // final M record = recordBuilder.getRecord(); + // M cleansed_rec = (M)record.toBuilder() + // .clearField(record.getDescriptorForType().findFieldByName("vector_data")) + // .build(); + final byte[] serialized = typedSerializer.serialize(metaData, recordBuilder.getRecordType(), recordBuilder.getRecord(), getTimer()); final FDBRecordVersion splitVersion = useOldVersionFormat() ? null : version; final SplitHelper.SizeInfo sizeInfo = new SplitHelper.SizeInfo(); diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java index 07d1af133d..a462cbea0a 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/FDBStoreTimer.java @@ -757,6 +757,14 @@ public enum Counts implements Count { LOCKS_ATTEMPTED("number of attempts to register a lock", false), /** Count of the locks released. */ LOCKS_RELEASED("number of locks released", false), + VECTOR_NODE_READS("intermediate nodes read", false), + VECTOR_NODE_READ_BYTES("intermediate node bytes read", true), + VECTOR_NODE0_READS("intermediate nodes read", false), + VECTOR_NODE0_READ_BYTES("intermediate node bytes read", true), + VECTOR_NODE_WRITES("intermediate nodes written", false), + VECTOR_NODE_WRITE_BYTES("intermediate node bytes written", true), + VECTOR_NODE0_WRITES("intermediate nodes written", false), + VECTOR_NODE0_WRITE_BYTES("intermediate node bytes written", true), ; private final String title; diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java new file mode 100644 index 0000000000..b3131bb0cb --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanBounds.java @@ -0,0 +1,115 @@ +/* + * MultidimensionalIndexScanBounds.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.hnsw.Vector; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.RecordCoreException; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.query.expressions.Comparisons; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * TODO. + */ +@API(API.Status.EXPERIMENTAL) +public class VectorIndexScanBounds implements IndexScanBounds { + @Nonnull + private final TupleRange prefixRange; + + @Nonnull + private final Comparisons.Type comparisonType; + @Nullable + private final Vector queryVector; + private final int limit; + + @Nonnull + private final TupleRange suffixRange; + + public VectorIndexScanBounds(@Nonnull final TupleRange prefixRange, + @Nonnull final Comparisons.Type comparisonType, + @Nullable final Vector queryVector, + final int limit, + @Nonnull final TupleRange suffixRange) { + this.prefixRange = prefixRange; + this.comparisonType = comparisonType; + this.queryVector = queryVector; + this.limit = limit; + this.suffixRange = suffixRange; + } + + @Nonnull + @Override + public IndexScanType getScanType() { + return IndexScanType.BY_VALUE; + } + + @Nonnull + public TupleRange getPrefixRange() { + return prefixRange; + } + + @Nonnull + public Comparisons.Type getComparisonType() { + return comparisonType; + } + + @Nullable + public Vector getQueryVector() { + return queryVector; + } + + public int getLimit() { + return limit; + } + + public int getAdjustedLimit() { + switch (getComparisonType()) { + case DISTANCE_RANK_LESS_THAN: + return limit - 1; + case DISTANCE_RANK_LESS_THAN_OR_EQUAL: + return limit; + default: + throw new RecordCoreException("unsupported comparison"); + } + } + + @Nonnull + public TupleRange getSuffixRange() { + return suffixRange; + } + + public boolean isWithinLimit(int rank) { + switch (getComparisonType()) { + case DISTANCE_RANK_EQUALS: + return rank == limit; + case DISTANCE_RANK_LESS_THAN: + return rank < limit; + case DISTANCE_RANK_LESS_THAN_OR_EQUAL: + return rank <= limit; + default: + throw new RecordCoreException("unsupported comparison"); + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java new file mode 100644 index 0000000000..88e69b6748 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/VectorIndexScanComparisons.java @@ -0,0 +1,332 @@ +/* + * MultidimensionalIndexScanComparisons.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2022 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.record.EvaluationContext; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.PlanDeserializer; +import com.apple.foundationdb.record.PlanHashable; +import com.apple.foundationdb.record.PlanSerializationContext; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.planprotos.PIndexScanParameters; +import com.apple.foundationdb.record.planprotos.PVectorIndexScanComparisons; +import com.apple.foundationdb.record.query.expressions.Comparisons; +import com.apple.foundationdb.record.query.expressions.Comparisons.DistanceRankValueComparison; +import com.apple.foundationdb.record.query.plan.ScanComparisons; +import com.apple.foundationdb.record.query.plan.cascades.AliasMap; +import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; +import com.apple.foundationdb.record.query.plan.cascades.explain.Attribute; +import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokens; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; +import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.Objects; +import java.util.Set; + +/** + * {@link ScanComparisons} for use in a multidimensional index scan. + */ +@API(API.Status.UNSTABLE) +public class VectorIndexScanComparisons implements IndexScanParameters { + @Nonnull + private final ScanComparisons prefixScanComparisons; + @Nonnull + private final DistanceRankValueComparison distanceRankValueComparison; + @Nonnull + private final ScanComparisons suffixScanComparisons; + + public VectorIndexScanComparisons(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final ScanComparisons suffixKeyComparisonRanges) { + this.prefixScanComparisons = prefixScanComparisons; + this.distanceRankValueComparison = distanceRankValueComparison; + this.suffixScanComparisons = suffixKeyComparisonRanges; + } + + @Nonnull + @Override + public IndexScanType getScanType() { + return IndexScanType.BY_VALUE; + } + + @Nonnull + public ScanComparisons getPrefixScanComparisons() { + return prefixScanComparisons; + } + + @Nonnull + public DistanceRankValueComparison getDistanceRankValueComparison() { + return distanceRankValueComparison; + } + + @Nonnull + public ScanComparisons getSuffixScanComparisons() { + return suffixScanComparisons; + } + + @Nonnull + @Override + public VectorIndexScanBounds bind(@Nonnull final FDBRecordStoreBase store, @Nonnull final Index index, + @Nonnull final EvaluationContext context) { + return new VectorIndexScanBounds(prefixScanComparisons.toTupleRange(store, context), + distanceRankValueComparison.getType(), distanceRankValueComparison.getVector(store, context), + distanceRankValueComparison.getLimit(store, context), suffixScanComparisons.toTupleRange(store, context)); + } + + @Override + public int planHash(@Nonnull PlanHashMode mode) { + return PlanHashable.objectsPlanHash(mode, prefixScanComparisons, distanceRankValueComparison, + suffixScanComparisons); + } + + @Override + public boolean isUnique(@Nonnull Index index) { + return prefixScanComparisons.isEquality() && prefixScanComparisons.size() == index.getColumnSize(); + } + + @Nonnull + @Override + public ExplainTokensWithPrecedence explain() { + @Nullable var tupleRange = prefixScanComparisons.toTupleRangeWithoutContext(); + final var prefix = tupleRange == null + ? prefixScanComparisons.explain().getExplainTokens() + : new ExplainTokens().addToString(tupleRange); + + ExplainTokens distanceRank; + try { + @Nullable var vector = distanceRankValueComparison.getVector(null, null); + int limit = distanceRankValueComparison.getLimit(null, null); + distanceRank = + new ExplainTokens().addNested(vector == null + ? new ExplainTokens().addKeyword("null") + : new ExplainTokens().addToString(vector)); + distanceRank.addKeyword(distanceRankValueComparison.getType().name()).addWhitespace().addToString(limit); + } catch (final Comparisons.EvaluationContextRequiredException e) { + distanceRank = + new ExplainTokens().addNested(distanceRankValueComparison.explain().getExplainTokens()); + } + + tupleRange = suffixScanComparisons.toTupleRangeWithoutContext(); + final var suffix = tupleRange == null + ? suffixScanComparisons.explain().getExplainTokens() + : new ExplainTokens().addToString(tupleRange); + + return ExplainTokensWithPrecedence.of(prefix.addOptionalWhitespace().addToString(":{").addOptionalWhitespace() + .addNested(distanceRank).addOptionalWhitespace().addToString("}:").addOptionalWhitespace().addNested(suffix)); + } + + @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance") + @Override + public void getPlannerGraphDetails(@Nonnull ImmutableList.Builder detailsBuilder, @Nonnull ImmutableMap.Builder attributeMapBuilder) { + @Nullable TupleRange tupleRange = prefixScanComparisons.toTupleRangeWithoutContext(); + if (tupleRange != null) { + detailsBuilder.add("prefix: " + tupleRange.getLowEndpoint().toString(false) + "{{plow}}, {{phigh}}" + tupleRange.getHighEndpoint().toString(true)); + attributeMapBuilder.put("plow", Attribute.gml(tupleRange.getLow() == null ? "-∞" : tupleRange.getLow().toString())); + attributeMapBuilder.put("phigh", Attribute.gml(tupleRange.getHigh() == null ? "∞" : tupleRange.getHigh().toString())); + } else { + detailsBuilder.add("prefix comparisons: {{pcomparisons}}"); + attributeMapBuilder.put("pcomparisons", Attribute.gml(prefixScanComparisons.toString())); + } + + try { + @Nullable var vector = distanceRankValueComparison.getVector(null, null); + int limit = distanceRankValueComparison.getLimit(null, null); + detailsBuilder.add("distanceRank: {{vector}} {{type}} {{limit}}"); + attributeMapBuilder.put("vector", Attribute.gml(String.valueOf(vector))); + attributeMapBuilder.put("type", Attribute.gml(distanceRankValueComparison.getType())); + attributeMapBuilder.put("limit", Attribute.gml(limit)); + } catch (final Comparisons.EvaluationContextRequiredException e) { + detailsBuilder.add("distanceRank: {{comparison}}"); + attributeMapBuilder.put("comparison", Attribute.gml(distanceRankValueComparison)); + } + + tupleRange = suffixScanComparisons.toTupleRangeWithoutContext(); + if (tupleRange != null) { + detailsBuilder.add("suffix: " + tupleRange.getLowEndpoint().toString(false) + "{{slow}}, {{shigh}}" + tupleRange.getHighEndpoint().toString(true)); + attributeMapBuilder.put("slow", Attribute.gml(tupleRange.getLow() == null ? "-∞" : tupleRange.getLow().toString())); + attributeMapBuilder.put("shigh", Attribute.gml(tupleRange.getHigh() == null ? "∞" : tupleRange.getHigh().toString())); + } else { + detailsBuilder.add("suffix comparisons: {{scomparisons}}"); + attributeMapBuilder.put("scomparisons", Attribute.gml(suffixScanComparisons.toString())); + } + } + + @Nonnull + @Override + public Set getCorrelatedTo() { + final ImmutableSet.Builder correlatedToBuilder = ImmutableSet.builder(); + correlatedToBuilder.addAll(prefixScanComparisons.getCorrelatedTo()); + correlatedToBuilder.addAll(distanceRankValueComparison.getCorrelatedTo()); + correlatedToBuilder.addAll(suffixScanComparisons.getCorrelatedTo()); + return correlatedToBuilder.build(); + } + + @Nonnull + @Override + public IndexScanParameters rebase(@Nonnull final AliasMap translationMap) { + return translateCorrelations(TranslationMap.rebaseWithAliasMap(translationMap), false); + } + + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public boolean semanticEquals(@Nullable final Object other, @Nonnull final AliasMap aliasMap) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + final VectorIndexScanComparisons that = (VectorIndexScanComparisons)other; + + if (!prefixScanComparisons.semanticEquals(that.prefixScanComparisons, aliasMap)) { + return false; + } + + if (!distanceRankValueComparison.semanticEquals(that.distanceRankValueComparison, aliasMap)) { + return false; + } + return suffixScanComparisons.semanticEquals(that.suffixScanComparisons, aliasMap); + } + + @Override + public int semanticHashCode() { + int hashCode = prefixScanComparisons.semanticHashCode(); + hashCode = 31 * hashCode + distanceRankValueComparison.semanticHashCode(); + return 31 * hashCode + suffixScanComparisons.semanticHashCode(); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public IndexScanParameters translateCorrelations(@Nonnull final TranslationMap translationMap, + final boolean shouldSimplifyValues) { + final ScanComparisons translatedPrefixScanComparisons = + prefixScanComparisons.translateCorrelations(translationMap, shouldSimplifyValues); + + final DistanceRankValueComparison translatedDistanceRankValueComparison = + distanceRankValueComparison.translateCorrelations(translationMap, shouldSimplifyValues); + + final ScanComparisons translatedSuffixKeyScanComparisons = + suffixScanComparisons.translateCorrelations(translationMap, shouldSimplifyValues); + + if (translatedPrefixScanComparisons != prefixScanComparisons || + translatedDistanceRankValueComparison != distanceRankValueComparison || + translatedSuffixKeyScanComparisons != suffixScanComparisons) { + return withComparisons(translatedPrefixScanComparisons, translatedDistanceRankValueComparison, + translatedSuffixKeyScanComparisons); + } + return this; + } + + @Nonnull + protected VectorIndexScanComparisons withComparisons(@Nonnull final ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nonnull final ScanComparisons suffixKeyScanComparisons) { + return new VectorIndexScanComparisons(prefixScanComparisons, distanceRankValueComparison, + suffixKeyScanComparisons); + } + + @Override + public String toString() { + return "BY_VALUE(VECTOR):" + prefixScanComparisons + ":" + distanceRankValueComparison + ":" + suffixScanComparisons; + } + + @Override + @SpotBugsSuppressWarnings("EQ_UNUSUAL") + @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") + public boolean equals(final Object o) { + return semanticEquals(o, AliasMap.emptyMap()); + } + + @Override + public int hashCode() { + return semanticHashCode(); + } + + @Nonnull + @Override + public PVectorIndexScanComparisons toProto(@Nonnull final PlanSerializationContext serializationContext) { + final PVectorIndexScanComparisons.Builder builder = PVectorIndexScanComparisons.newBuilder(); + builder.setPrefixScanComparisons(prefixScanComparisons.toProto(serializationContext)); + builder.setDistanceRankValueComparison(distanceRankValueComparison.toProto(serializationContext)); + builder.setSuffixScanComparisons(suffixScanComparisons.toProto(serializationContext)); + return builder.build(); + } + + @Nonnull + @Override + public PIndexScanParameters toIndexScanParametersProto(@Nonnull final PlanSerializationContext serializationContext) { + return PIndexScanParameters.newBuilder().setVectorIndexScanComparisons(toProto(serializationContext)).build(); + } + + @Nonnull + public static VectorIndexScanComparisons fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorIndexScanComparisons vectorIndexScanComparisonsProto) { + return new VectorIndexScanComparisons(ScanComparisons.fromProto(serializationContext, + Objects.requireNonNull(vectorIndexScanComparisonsProto.getPrefixScanComparisons())), + Objects.requireNonNull(DistanceRankValueComparison.fromProto(serializationContext, vectorIndexScanComparisonsProto.getDistanceRankValueComparison())), + ScanComparisons.fromProto(serializationContext, Objects.requireNonNull(vectorIndexScanComparisonsProto.getSuffixScanComparisons()))); + } + + @Nonnull + public static VectorIndexScanComparisons byValue(@Nullable ScanComparisons prefixScanComparisons, + @Nonnull final DistanceRankValueComparison distanceRankValueComparison, + @Nullable ScanComparisons suffixKeyScanComparisons) { + if (prefixScanComparisons == null) { + prefixScanComparisons = ScanComparisons.EMPTY; + } + + if (suffixKeyScanComparisons == null) { + suffixKeyScanComparisons = ScanComparisons.EMPTY; + } + + return new VectorIndexScanComparisons(prefixScanComparisons, distanceRankValueComparison, suffixKeyScanComparisons); + } + + /** + * Deserializer. + */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PVectorIndexScanComparisons.class; + } + + @Nonnull + @Override + public VectorIndexScanComparisons fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorIndexScanComparisons vectorIndexScanComparisonsProto) { + return VectorIndexScanComparisons.fromProto(serializationContext, vectorIndexScanComparisonsProto); + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java new file mode 100644 index 0000000000..eab705e67d --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexHelper.java @@ -0,0 +1,113 @@ +/* + * VectorIndexHelper.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.hnsw.HNSW; +import com.apple.foundationdb.async.hnsw.Metrics; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.provider.common.StoreTimer; + +import javax.annotation.Nonnull; + +/** + * Helper functions for index maintainers that use a {@link HNSW}. + */ +@API(API.Status.EXPERIMENTAL) +public class VectorIndexHelper { + private VectorIndexHelper() { + } + + /** + * Parse standard options into {@link HNSW.Config}. + * @param index the index definition to get options from + * @return parsed config options + */ + public static HNSW.Config getConfig(@Nonnull final Index index) { + final HNSW.ConfigBuilder builder = HNSW.newConfigBuilder(); + final String hnswMetricOption = index.getOption(IndexOptions.HNSW_METRIC); + if (hnswMetricOption != null) { + builder.setMetric(Metrics.valueOf(hnswMetricOption).getMetric()); + } + final String hnswMOption = index.getOption(IndexOptions.HNSW_M); + if (hnswMOption != null) { + builder.setM(Integer.parseInt(hnswMOption)); + } + final String hnswMMaxOption = index.getOption(IndexOptions.HNSW_M_MAX); + if (hnswMMaxOption != null) { + builder.setMMax(Integer.parseInt(hnswMMaxOption)); + } + final String hnswMMax0Option = index.getOption(IndexOptions.HNSW_M_MAX_0); + if (hnswMMax0Option != null) { + builder.setMMax0(Integer.parseInt(hnswMMax0Option)); + } + final String hnswEfSearchOption = index.getOption(IndexOptions.HNSW_EF_SEARCH); + if (hnswEfSearchOption != null) { + builder.setEfSearch(Integer.parseInt(hnswEfSearchOption)); + } + final String hnswEfConstructionOption = index.getOption(IndexOptions.HNSW_EF_CONSTRUCTION); + if (hnswEfConstructionOption != null) { + builder.setEfConstruction(Integer.parseInt(hnswEfConstructionOption)); + } + final String hnswExtendCandidatesOption = index.getOption(IndexOptions.HNSW_EXTEND_CANDIDATES); + if (hnswExtendCandidatesOption != null) { + builder.setExtendCandidates(Boolean.parseBoolean(hnswExtendCandidatesOption)); + } + final String hnswKeepPrunedConnectionsOption = index.getOption(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS); + if (hnswKeepPrunedConnectionsOption != null) { + builder.setKeepPrunedConnections(Boolean.parseBoolean(hnswKeepPrunedConnectionsOption)); + } + + return builder.build(); + } + + /** + * Instrumentation events specific to R-tree index maintenance. + */ + public enum Events implements StoreTimer.DetailEvent { + VECTOR_SCAN("scanning the HNSW of a vector index"), + VECTOR_SKIP_SCAN("skip scan the prefix tuples of a vector index scan"); + + private final String title; + private final String logKey; + + Events(String title, String logKey) { + this.title = title; + this.logKey = (logKey != null) ? logKey : StoreTimer.DetailEvent.super.logKey(); + } + + Events(String title) { + this(title, null); + } + + @Override + public String title() { + return title; + } + + @Override + @Nonnull + public String logKey() { + return this.logKey; + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java new file mode 100644 index 0000000000..b48abdedd5 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainer.java @@ -0,0 +1,495 @@ +/* + * VectorIndexMaintainer.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.ReadTransaction; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.hnsw.HNSW; +import com.apple.foundationdb.async.hnsw.HNSW.Config; +import com.apple.foundationdb.async.hnsw.Node; +import com.apple.foundationdb.async.hnsw.NodeReference; +import com.apple.foundationdb.async.hnsw.NodeReferenceAndNode; +import com.apple.foundationdb.async.hnsw.NodeReferenceWithDistance; +import com.apple.foundationdb.async.hnsw.OnReadListener; +import com.apple.foundationdb.async.hnsw.OnWriteListener; +import com.apple.foundationdb.async.hnsw.Vector; +import com.apple.foundationdb.record.CursorStreamingMode; +import com.apple.foundationdb.record.EndpointType; +import com.apple.foundationdb.record.ExecuteProperties; +import com.apple.foundationdb.record.IndexEntry; +import com.apple.foundationdb.record.IndexScanType; +import com.apple.foundationdb.record.PipelineOperation; +import com.apple.foundationdb.record.RecordCoreException; +import com.apple.foundationdb.record.RecordCursor; +import com.apple.foundationdb.record.RecordCursorContinuation; +import com.apple.foundationdb.record.RecordCursorProto; +import com.apple.foundationdb.record.ScanProperties; +import com.apple.foundationdb.record.TupleRange; +import com.apple.foundationdb.record.cursors.AsyncLockCursor; +import com.apple.foundationdb.record.cursors.ChainedCursor; +import com.apple.foundationdb.record.cursors.LazyCursor; +import com.apple.foundationdb.record.cursors.ListCursor; +import com.apple.foundationdb.record.locking.LockIdentifier; +import com.apple.foundationdb.record.metadata.Key; +import com.apple.foundationdb.record.metadata.expressions.KeyExpression; +import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; +import com.apple.foundationdb.record.provider.common.StoreTimer; +import com.apple.foundationdb.record.provider.foundationdb.FDBIndexableRecord; +import com.apple.foundationdb.record.provider.foundationdb.FDBStoreTimer; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerState; +import com.apple.foundationdb.record.provider.foundationdb.IndexScanBounds; +import com.apple.foundationdb.record.provider.foundationdb.KeyValueCursor; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanBounds; +import com.apple.foundationdb.record.query.QueryToKeyMatcher; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.ByteArrayUtil2; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.foundationdb.tuple.TupleHelpers; +import com.google.common.base.Verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * An index maintainer for keeping an {@link HNSW}. + */ +@API(API.Status.EXPERIMENTAL) +public class VectorIndexMaintainer extends StandardIndexMaintainer { + @Nonnull + private final Config config; + + public VectorIndexMaintainer(IndexMaintainerState state) { + super(state); + this.config = VectorIndexHelper.getConfig(state.index); + } + + @Nonnull + public Config getConfig() { + return config; + } + + @SuppressWarnings("resource") + @Nonnull + @Override + public RecordCursor scan(@Nonnull final IndexScanBounds scanBounds, @Nullable final byte[] continuation, + @Nonnull final ScanProperties scanProperties) { + if (!scanBounds.getScanType().equals(IndexScanType.BY_VALUE)) { + throw new RecordCoreException("Can only scan vector index by value."); + } + if (!(scanBounds instanceof VectorIndexScanBounds)) { + throw new RecordCoreException("Need proper vector index scan bounds."); + } + final VectorIndexScanBounds vectorIndexScanBounds = (VectorIndexScanBounds)scanBounds; + + final KeyWithValueExpression keyWithValueExpression = getKeyWithValueExpression(state.index.getRootExpression()); + final int prefixSize = keyWithValueExpression.getSplitPoint(); + + final ExecuteProperties executeProperties = scanProperties.getExecuteProperties(); + final ScanProperties innerScanProperties = scanProperties.with(ExecuteProperties::clearSkipAndLimit); + final Subspace indexSubspace = getIndexSubspace(); + final FDBStoreTimer timer = Objects.requireNonNull(state.context.getTimer()); + + // + // Skip-scan through the prefixes in a way that we only consider each distinct prefix. That skip scan + // forms the outer of a join with an inner that searches the R-tree for that prefix using the + // spatial predicates of the scan bounds. + // + return RecordCursor.flatMapPipelined(prefixSkipScan(prefixSize, timer, vectorIndexScanBounds, innerScanProperties), + (prefixTuple, innerContinuation) -> { + final Subspace hnswSubspace; + if (prefixTuple != null) { + Verify.verify(prefixTuple.size() == prefixSize); + hnswSubspace = indexSubspace.subspace(prefixTuple); + } else { + hnswSubspace = indexSubspace; + } + + if (innerContinuation != null) { + final RecordCursorProto.VectorIndexScanContinuation parsedContinuation = + Continuation.fromBytes(innerContinuation); + final ImmutableList.Builder indexEntriesBuilder = ImmutableList.builder(); + for (int i = 0; i < parsedContinuation.getIndexEntriesCount(); i ++) { + final RecordCursorProto.VectorIndexScanContinuation.IndexEntry indexEntryProto = + parsedContinuation.getIndexEntries(i); + indexEntriesBuilder.add(new IndexEntry(state.index, + Tuple.fromBytes(indexEntryProto.getKey().toByteArray()), + Tuple.fromBytes(indexEntryProto.getValue().toByteArray()))); + } + final ImmutableList indexEntries = indexEntriesBuilder.build(); + return new ListCursor<>(indexEntries, parsedContinuation.getInnerContinuation().toByteArray()) + .mapResult(result -> + result.withContinuation(new Continuation(indexEntries, result.getContinuation()))); + } + + final HNSW hnsw = new HNSW(hnswSubspace, getExecutor(), getConfig(), + OnWriteListener.NOOP, new OnRead(timer)); + final ReadTransaction transaction = state.context.readTransaction(true); + return new LazyCursor<>( + state.context.acquireReadLock(new LockIdentifier(hnswSubspace)) + .thenApply(lock -> + new AsyncLockCursor<>(lock, + new LazyCursor<>( + kNearestNeighborSearch(prefixTuple, hnsw, transaction, vectorIndexScanBounds), + getExecutor()))), + state.context.getExecutor()); + }, + continuation, + state.store.getPipelineSize(PipelineOperation.INDEX_TO_RECORD)) + .skipThenLimit(executeProperties.getSkip(), executeProperties.getReturnedRowLimit()); + } + + @SuppressWarnings({"resource", "checkstyle:MethodName"}) + @Nonnull + private CompletableFuture> kNearestNeighborSearch(@Nullable final Tuple prefixTuple, + @Nonnull final HNSW hnsw, + @Nonnull final ReadTransaction transaction, + @Nonnull final VectorIndexScanBounds vectorIndexScanBounds) { + return hnsw.kNearestNeighborsSearch(transaction, vectorIndexScanBounds.getAdjustedLimit(), 100, + Objects.requireNonNull(vectorIndexScanBounds.getQueryVector()).toHalfVector()) + .thenApply(nearestNeighbors -> { + final ImmutableList.Builder nearestNeighborEntriesBuilder = ImmutableList.builder(); + for (final NodeReferenceAndNode nearestNeighbor : nearestNeighbors) { + if (vectorIndexScanBounds.getSuffixRange().contains(nearestNeighbor.getNode().getPrimaryKey())) { + nearestNeighborEntriesBuilder.add(toIndexEntry(prefixTuple, nearestNeighbor)); + } + } + final ImmutableList nearestNeighborsEntries = nearestNeighborEntriesBuilder.build(); + return new ListCursor<>(getExecutor(), nearestNeighborsEntries, 0) + .mapResult(result -> { + final RecordCursorContinuation continuation = result.getContinuation(); + if (continuation.isEnd()) { + return result; + } + return result.withContinuation(new Continuation(nearestNeighborsEntries, continuation)); + }); + }); + } + + @Nonnull + private IndexEntry toIndexEntry(final Tuple prefixTuple, final NodeReferenceAndNode nearestNeighbor) { + final List keyItems = Lists.newArrayList(); + if (prefixTuple != null) { + keyItems.addAll(prefixTuple.getItems()); + } + final Node node = nearestNeighbor.getNode(); + final NodeReferenceWithDistance nodeReferenceWithDistance = + nearestNeighbor.getNodeReferenceWithDistance(); + keyItems.addAll(node.getPrimaryKey().getItems()); + final List valueItems = Lists.newArrayList(); + valueItems.add(nodeReferenceWithDistance.getVector().getRawData()); + return new IndexEntry(state.index, Tuple.fromList(keyItems), + Tuple.fromList(valueItems)); + } + + @Nonnull + @Override + public RecordCursor scan(@Nonnull final IndexScanType scanType, @Nonnull final TupleRange range, + @Nullable final byte[] continuation, @Nonnull final ScanProperties scanProperties) { + throw new RecordCoreException("index maintainer does not support this scan api"); + } + + @Nonnull + private Function> prefixSkipScan(final int prefixSize, + @Nonnull final StoreTimer timer, + @Nonnull final VectorIndexScanBounds vectorIndexScanBounds, + @Nonnull final ScanProperties innerScanProperties) { + final Function> outerFunction; + if (prefixSize > 0) { + outerFunction = outerContinuation -> timer.instrument(MultiDimensionalIndexHelper.Events.MULTIDIMENSIONAL_SKIP_SCAN, + new ChainedCursor<>(state.context, + lastKeyOptional -> nextPrefixTuple(vectorIndexScanBounds.getPrefixRange(), + prefixSize, lastKeyOptional.orElse(null), innerScanProperties), + Tuple::pack, + Tuple::fromBytes, + outerContinuation, + innerScanProperties)); + } else { + outerFunction = outerContinuation -> RecordCursor.fromFuture(CompletableFuture.completedFuture(null)); + } + return outerFunction; + } + + @SuppressWarnings({"resource", "PMD.CloseResource"}) + private CompletableFuture> nextPrefixTuple(@Nonnull final TupleRange prefixRange, + final int prefixSize, + @Nullable final Tuple lastPrefixTuple, + @Nonnull final ScanProperties scanProperties) { + final Subspace indexSubspace = getIndexSubspace(); + final KeyValueCursor cursor; + if (lastPrefixTuple == null) { + cursor = KeyValueCursor.Builder.withSubspace(indexSubspace) + .setContext(state.context) + .setRange(prefixRange) + .setContinuation(null) + .setScanProperties(scanProperties.setStreamingMode(CursorStreamingMode.ITERATOR) + .with(innerExecuteProperties -> innerExecuteProperties.setReturnedRowLimit(1))) + .build(); + } else { + KeyValueCursor.Builder builder = KeyValueCursor.Builder.withSubspace(indexSubspace) + .setContext(state.context) + .setContinuation(null) + .setScanProperties(scanProperties) + .setScanProperties(scanProperties.setStreamingMode(CursorStreamingMode.ITERATOR) + .with(innerExecuteProperties -> innerExecuteProperties.setReturnedRowLimit(1))); + + cursor = builder.setLow(indexSubspace.pack(lastPrefixTuple), EndpointType.RANGE_EXCLUSIVE) + .setHigh(prefixRange.getHigh(), prefixRange.getHighEndpoint()) + .build(); + } + + return cursor.onNext().thenApply(next -> { + cursor.close(); + if (next.hasNext()) { + final KeyValue kv = Objects.requireNonNull(next.get()); + return Optional.of(TupleHelpers.subTuple(indexSubspace.unpack(kv.getKey()), 0, prefixSize)); + } + return Optional.empty(); + }); + } + + @Override + protected CompletableFuture updateIndexKeys(@Nonnull final FDBIndexableRecord savedRecord, + final boolean remove, + @Nonnull final List indexEntries) { + final KeyWithValueExpression keyWithValueExpression = getKeyWithValueExpression(state.index.getRootExpression()); + final int prefixSize = keyWithValueExpression.getColumnSize(); + final Subspace indexSubspace = getIndexSubspace(); + final var futures = indexEntries.stream().map(indexEntry -> { + final var indexKeyItems = indexEntry.getKey().getItems(); + final Tuple prefixKey = Tuple.fromList(indexKeyItems.subList(0, prefixSize)); + + final Subspace rtSubspace; + if (prefixSize > 0) { + rtSubspace = indexSubspace.subspace(prefixKey); + } else { + rtSubspace = indexSubspace; + } + return state.context.doWithWriteLock(new LockIdentifier(rtSubspace), () -> { + final List primaryKeyParts = Lists.newArrayList(savedRecord.getPrimaryKey().getItems()); + state.index.trimPrimaryKey(primaryKeyParts); + final Tuple trimmedPrimaryKey = Tuple.fromList(primaryKeyParts); + final FDBStoreTimer timer = Objects.requireNonNull(getTimer()); + final HNSW hnsw = + new HNSW(rtSubspace, getExecutor(), getConfig(), new OnWrite(timer), OnReadListener.NOOP); + if (remove) { + throw new UnsupportedOperationException("not implemented"); + } else { + return hnsw.insert(state.transaction, trimmedPrimaryKey, + Vector.HalfVector.halfVectorFromBytes(indexEntry.getValue().getBytes(0))); + } + }); + }).collect(Collectors.toList()); + return AsyncUtil.whenAll(futures); + } + + @Override + public boolean canDeleteWhere(@Nonnull final QueryToKeyMatcher matcher, @Nonnull final Key.Evaluated evaluated) { + if (!super.canDeleteWhere(matcher, evaluated)) { + return false; + } + return evaluated.size() <= getKeyWithValueExpression(state.index.getRootExpression()).getColumnSize(); + } + + @Override + public CompletableFuture deleteWhere(@Nonnull final Transaction tr, @Nonnull final Tuple prefix) { + Verify.verify(getKeyWithValueExpression(state.index.getRootExpression()).getColumnSize() >= prefix.size()); + return super.deleteWhere(tr, prefix); + } + + /** + * TODO. + */ + @Nonnull + private static KeyWithValueExpression getKeyWithValueExpression(@Nonnull final KeyExpression root) { + if (root instanceof KeyWithValueExpression) { + return (KeyWithValueExpression)root; + } + throw new RecordCoreException("structure of vector index is not supported"); + } + + static class OnRead implements OnReadListener { + @Nonnull + private final FDBStoreTimer timer; + + public OnRead(@Nonnull final FDBStoreTimer timer) { + this.timer = timer; + } + + @Override + public CompletableFuture> onAsyncRead(@Nonnull final CompletableFuture> future) { + return timer.instrument(VectorIndexHelper.Events.VECTOR_SCAN, future); + } + + @Override + public void onNodeRead(final int layer, @Nonnull final Node node) { + if (layer == 0) { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_READS); + } else { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_READS); + } + } + + @Override + public void onKeyValueRead(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + final int keyLength = key.length; + final int valueLength = value.length; + + timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_KEY); + timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_KEY_BYTES, keyLength); + timer.increment(FDBStoreTimer.Counts.LOAD_INDEX_VALUE_BYTES, valueLength); + + if (layer == 0) { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_READ_BYTES); + } else { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_READ_BYTES); + } + } + } + + static class OnWrite implements OnWriteListener { + @Nonnull + private final FDBStoreTimer timer; + + public OnWrite(@Nonnull final FDBStoreTimer timer) { + this.timer = timer; + } + + @Override + public void onNodeWritten(final int layer, @Nonnull final Node node) { + if (layer == 0) { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_WRITES); + } else { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_WRITES); + } + } + + @Override + public void onKeyValueWritten(final int layer, @Nonnull final byte[] key, @Nonnull final byte[] value) { + final int keyLength = key.length; + final int valueLength = value.length; + + final int totalLength = keyLength + valueLength; + timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_KEY); + timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_KEY_BYTES, keyLength); + timer.increment(FDBStoreTimer.Counts.SAVE_INDEX_VALUE_BYTES, valueLength); + + if (layer == 0) { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE0_WRITE_BYTES, totalLength); + } else { + timer.increment(FDBStoreTimer.Counts.VECTOR_NODE_WRITE_BYTES, totalLength); + } + } + } + + private static class Continuation implements RecordCursorContinuation { + @Nonnull + private final List indexEntries; + @Nonnull + private final RecordCursorContinuation innerContinuation; + + @Nullable + private ByteString cachedByteString; + @Nullable + private byte[] cachedBytes; + + private Continuation(@Nonnull final List indexEntries, + @Nonnull final RecordCursorContinuation innerContinuation) { + this.indexEntries = ImmutableList.copyOf(indexEntries); + this.innerContinuation = innerContinuation; + } + + @Nonnull + public List getIndexEntries() { + return indexEntries; + } + + @Nonnull + public RecordCursorContinuation getInnerContinuation() { + return innerContinuation; + } + + @Nonnull + @Override + public ByteString toByteString() { + if (isEnd()) { + return ByteString.EMPTY; + } + + if (cachedByteString == null) { + final RecordCursorProto.VectorIndexScanContinuation.Builder builder = + RecordCursorProto.VectorIndexScanContinuation.newBuilder(); + for (final var indexEntry : indexEntries) { + builder.addIndexEntries(RecordCursorProto.VectorIndexScanContinuation.IndexEntry.newBuilder() + .setKey(ByteString.copyFrom(indexEntry.getKey().pack())) + .setValue(ByteString.copyFrom(indexEntry.getKey().pack())) + .build()); + } + + cachedByteString = builder + .setInnerContinuation(Objects.requireNonNull(innerContinuation.toByteString())) + .build() + .toByteString(); + } + return cachedByteString; + } + + @Nullable + @Override + public byte[] toBytes() { + if (isEnd()) { + return null; + } + if (cachedBytes == null) { + cachedBytes = toByteString().toByteArray(); + } + return cachedBytes; + } + + @Override + public boolean isEnd() { + return getInnerContinuation().isEnd(); + } + + @Nonnull + private static RecordCursorProto.VectorIndexScanContinuation fromBytes(@Nonnull byte[] continuationBytes) { + try { + return RecordCursorProto.VectorIndexScanContinuation.parseFrom(continuationBytes); + } catch (InvalidProtocolBufferException ex) { + throw new RecordCoreException("error parsing continuation", ex) + .addLogInfo("raw_bytes", ByteArrayUtil2.loggable(continuationBytes)); + } + } + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java new file mode 100644 index 0000000000..d8b2444eef --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexMaintainerFactory.java @@ -0,0 +1,159 @@ +/* + * VectorIndexMaintainerFactory.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.async.hnsw.HNSW.Config; +import com.apple.foundationdb.record.logging.LogMessageKeys; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.metadata.IndexTypes; +import com.apple.foundationdb.record.metadata.IndexValidator; +import com.apple.foundationdb.record.metadata.MetaDataException; +import com.apple.foundationdb.record.metadata.MetaDataValidator; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainer; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerFactory; +import com.apple.foundationdb.record.provider.foundationdb.IndexMaintainerState; +import com.google.auto.service.AutoService; + +import javax.annotation.Nonnull; +import java.util.Arrays; +import java.util.Set; + +/** + * A factory for {@link VectorIndexMaintainer} indexes. + */ +@AutoService(IndexMaintainerFactory.class) +@API(API.Status.EXPERIMENTAL) +public class VectorIndexMaintainerFactory implements IndexMaintainerFactory { + static final String[] TYPES = { IndexTypes.VECTOR}; + + @Override + @Nonnull + public Iterable getIndexTypes() { + return Arrays.asList(TYPES); + } + + @Override + @Nonnull + public IndexValidator getIndexValidator(Index index) { + return new IndexValidator(index) { + @Override + public void validate(@Nonnull MetaDataValidator metaDataValidator) { + super.validate(metaDataValidator); + validateNotVersion(); + validateStructure(); + } + + /** + * TODO. + */ + private void validateStructure() { + // + // There is no structural constraint on the key expression of the index. We just happen to interpret + // things in specific ways: + // + // - without GroupingKeyExpression: + // - one HNSW for the entire table (ungrouped HNSW) + // - first column of the expression gives us access to the field containing the vector + // - with GroupingKeyExpression: + // - one HNSW for each grouping prefix + // - first column in the grouped expression gives us access to the field containing the vector + // + // In any case, the vector is always a half-precision-encoded vector of dimensionality + // blob.length / 2 (for now). + // + // TODO We do not support extraneous columns to support advanced covering index scans for now. That + // Will probably encoded by a KeyWithValueExpression in the root position (but not now) + // + } + + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public void validateChangedOptions(@Nonnull final Index oldIndex, + @Nonnull final Set changedOptions) { + if (!changedOptions.isEmpty()) { + // Allow changing from unspecified to the default (or vice versa), but not otherwise. + final Config oldOptions = VectorIndexHelper.getConfig(oldIndex); + final Config newOptions = VectorIndexHelper.getConfig(index); + if (changedOptions.contains(IndexOptions.HNSW_METRIC)) { + if (oldOptions.getMetric() != newOptions.getMetric()) { + throw new MetaDataException("HNSW metric changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_METRIC); + } + if (changedOptions.contains(IndexOptions.HNSW_M)) { + if (oldOptions.getM() != newOptions.getM()) { + throw new MetaDataException("HNSW M changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_M); + } + if (changedOptions.contains(IndexOptions.HNSW_M_MAX)) { + if (oldOptions.getMMax() != newOptions.getMMax()) { + throw new MetaDataException("HNSW mMax changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_M_MAX); + } + if (changedOptions.contains(IndexOptions.HNSW_M_MAX_0)) { + if (oldOptions.getMMax0() != newOptions.getMMax0()) { + throw new MetaDataException("HNSW mMax0 changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_M_MAX_0); + } + // efSearch can be overridden in every scenario + changedOptions.remove(IndexOptions.HNSW_EF_SEARCH); + if (changedOptions.contains(IndexOptions.HNSW_EF_CONSTRUCTION)) { + if (oldOptions.getEfConstruction() != newOptions.getEfConstruction()) { + throw new MetaDataException("HNSW efConstruction changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_EF_CONSTRUCTION); + } + if (changedOptions.contains(IndexOptions.HNSW_EXTEND_CANDIDATES)) { + if (oldOptions.isExtendCandidates() != newOptions.isExtendCandidates()) { + throw new MetaDataException("HNSW extendCandidates changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_EXTEND_CANDIDATES); + } + if (changedOptions.contains(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS)) { + if (oldOptions.isKeepPrunedConnections() != newOptions.isKeepPrunedConnections()) { + throw new MetaDataException("HNSW keepPrunedConnections changed", + LogMessageKeys.INDEX_NAME, index.getName()); + } + changedOptions.remove(IndexOptions.HNSW_KEEP_PRUNED_CONNECTIONS); + } + } + super.validateChangedOptions(oldIndex, changedOptions); + } + }; + } + + @Override + @Nonnull + public IndexMaintainer getIndexMaintainer(@Nonnull final IndexMaintainerState state) { + return new VectorIndexMaintainer(state); + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java index 1305ab01c4..731f499f09 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/expressions/Comparisons.java @@ -22,6 +22,7 @@ import com.apple.foundationdb.annotation.API; import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; +import com.apple.foundationdb.async.hnsw.Vector; import com.apple.foundationdb.record.Bindings; import com.apple.foundationdb.record.EvaluationContext; import com.apple.foundationdb.record.ObjectPlanHash; @@ -38,6 +39,7 @@ import com.apple.foundationdb.record.metadata.expressions.TupleFieldsHelper; import com.apple.foundationdb.record.planprotos.PComparison; import com.apple.foundationdb.record.planprotos.PComparison.PComparisonType; +import com.apple.foundationdb.record.planprotos.PDistanceRankValueComparison; import com.apple.foundationdb.record.planprotos.PInvertedFunctionComparison; import com.apple.foundationdb.record.planprotos.PListComparison; import com.apple.foundationdb.record.planprotos.PMultiColumnComparison; @@ -56,9 +58,6 @@ import com.apple.foundationdb.record.query.plan.cascades.ConstrainedBoolean; import com.apple.foundationdb.record.query.plan.cascades.Correlated; import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier; -import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; -import com.apple.foundationdb.record.query.plan.explain.ExplainTokens; -import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; import com.apple.foundationdb.record.query.plan.cascades.UsesValueEquivalence; import com.apple.foundationdb.record.query.plan.cascades.ValueEquivalence; import com.apple.foundationdb.record.query.plan.cascades.WithValue; @@ -68,6 +67,9 @@ import com.apple.foundationdb.record.query.plan.cascades.values.QuantifiedObjectValue; import com.apple.foundationdb.record.query.plan.cascades.values.Value; import com.apple.foundationdb.record.query.plan.cascades.values.translation.TranslationMap; +import com.apple.foundationdb.record.query.plan.explain.DefaultExplainFormatter; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokens; +import com.apple.foundationdb.record.query.plan.explain.ExplainTokensWithPrecedence; import com.apple.foundationdb.record.query.plan.plans.QueryResult; import com.apple.foundationdb.record.query.plan.serialization.PlanSerialization; import com.apple.foundationdb.record.util.ProtoUtils; @@ -632,7 +634,13 @@ public enum Type { @API(API.Status.EXPERIMENTAL) SORT(false), @API(API.Status.EXPERIMENTAL) - LIKE; + LIKE, + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_EQUALS(true), + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_LESS_THAN, + @API(API.Status.EXPERIMENTAL) + DISTANCE_RANK_LESS_THAN_OR_EQUAL; @Nonnull private static final Supplier> protoEnumBiMapSupplier = @@ -1504,6 +1512,12 @@ public static class ValueComparison implements Comparison { @Nonnull private final Supplier hashCodeSupplier; + protected ValueComparison(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PValueComparison valueComparisonProto) { + this(Type.fromProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getType())), + Value.fromValueProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getComparandValue()))); + } + public ValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue) { this(type, comparandValue, ParameterRelationshipGraph.unbound()); @@ -1660,7 +1674,7 @@ public int hashCode() { } public int computeHashCode() { - return Objects.hash(type.name(), relatedByEquality()); + return Objects.hash(type.name(), getComparandValue(), relatedByEquality()); } private Set relatedByEquality() { @@ -1687,7 +1701,12 @@ public Comparison withParameterRelationshipMap(@Nonnull final ParameterRelations @Nonnull @Override - public PValueComparison toProto(@Nonnull final PlanSerializationContext serializationContext) { + public Message toProto(@Nonnull final PlanSerializationContext serializationContext) { + return toValueComparisonProto(serializationContext); + } + + @Nonnull + public PValueComparison toValueComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { return PValueComparison.newBuilder() .setType(type.toProto(serializationContext)) .setComparandValue(comparandValue.toValueProto(serializationContext)) @@ -1697,14 +1716,13 @@ public PValueComparison toProto(@Nonnull final PlanSerializationContext serializ @Nonnull @Override public PComparison toComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { - return PComparison.newBuilder().setValueComparison(toProto(serializationContext)).build(); + return PComparison.newBuilder().setValueComparison(toValueComparisonProto(serializationContext)).build(); } @Nonnull public static ValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, @Nonnull final PValueComparison valueComparisonProto) { - return new ValueComparison(Type.fromProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getType())), - Value.fromValueProto(serializationContext, Objects.requireNonNull(valueComparisonProto.getComparandValue()))); + return new ValueComparison(serializationContext, valueComparisonProto); } /** @@ -1727,6 +1745,236 @@ public ValueComparison fromProto(@Nonnull final PlanSerializationContext seriali } } + public static class DistanceRankValueComparison extends ValueComparison { + private static final ObjectPlanHash BASE_HASH = new ObjectPlanHash("Distance-Rank-Value-Comparison"); + + @Nonnull + private final Value limitValue; + + protected DistanceRankValueComparison(@Nonnull PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + super(serializationContext, distanceRankValueComparisonProto.getSuper()); + this.limitValue = Value.fromValueProto(serializationContext, + Objects.requireNonNull(distanceRankValueComparisonProto.getLimitValue())); + } + + public DistanceRankValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue, + @Nonnull final Value limitValue) { + this(type, comparandValue, ParameterRelationshipGraph.unbound(), limitValue); + } + + public DistanceRankValueComparison(@Nonnull final Type type, @Nonnull final Value comparandValue, + @Nonnull final ParameterRelationshipGraph parameterRelationshipGraph, + @Nonnull final Value limitValue) { + super(type, comparandValue, parameterRelationshipGraph); + Verify.verify(type == Type.DISTANCE_RANK_EQUALS || + type == Type.DISTANCE_RANK_LESS_THAN || + type == Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL); + this.limitValue = limitValue; + } + + @Nonnull + public Value getLimitValue() { + return limitValue; + } + + @Nonnull + @Override + public Comparison withType(@Nonnull final Type newType) { + if (getType() == newType) { + return this; + } + return new ValueComparison(newType, getComparandValue(), parameterRelationshipGraph); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public ValueComparison withValue(@Nonnull final Value value) { + if (getComparandValue() == value) { + return this; + } + return new ValueComparison(getType(), value); + } + + @Nonnull + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public Optional replaceValuesMaybe(@Nonnull final Function> replacementFunction) { + return replacementFunction.apply(getComparandValue()) + .flatMap(replacedComparandValue -> + replacementFunction.apply(getLimitValue()).map(replacedLimitValue -> { + if (replacedComparandValue == getComparandValue() && + replacedLimitValue == getLimitValue()) { + return this; + } + return new DistanceRankValueComparison(getType(), replacedComparandValue, parameterRelationshipGraph, + replacedLimitValue); + })); + } + + @Nonnull + @Override + public DistanceRankValueComparison translateCorrelations(@Nonnull final TranslationMap translationMap, + final boolean shouldSimplifyValues) { + if (getComparandValue().getCorrelatedTo() + .stream() + .noneMatch(translationMap::containsSourceAlias) && + getLimitValue().getCorrelatedTo() + .stream() + .noneMatch(translationMap::containsSourceAlias)) { + return this; + } + + return new DistanceRankValueComparison(getType(), + getComparandValue().translateCorrelations(translationMap, shouldSimplifyValues), + parameterRelationshipGraph, + getLimitValue().translateCorrelations(translationMap, shouldSimplifyValues)); + } + + @Nonnull + @Override + public Set getCorrelatedTo() { + return ImmutableSet.builder() + .addAll(getComparandValue().getCorrelatedTo()) + .addAll(getLimitValue().getCorrelatedTo()) + .build(); + } + + @Nonnull + @Override + public ConstrainedBoolean semanticEqualsTyped(@Nonnull final Comparison other, @Nonnull final ValueEquivalence valueEquivalence) { + return super.semanticEqualsTyped(other, valueEquivalence) + .compose(ignored -> getLimitValue() + .semanticEquals(((DistanceRankValueComparison)other).getLimitValue(), + valueEquivalence)); + } + + @Nullable + @Override + @SuppressWarnings("PMD.CompareObjectsWithEquals") + public Boolean eval(@Nullable FDBRecordStoreBase store, @Nonnull EvaluationContext context, @Nullable Object v) { + throw new RecordCoreException("this comparison can only be evaluated using an index"); + } + + @Nonnull + @Override + public String typelessString() { + return getComparandValue() + ":" + getLimitValue(); + } + + @Override + public final boolean equals(final Object o) { + if (!(o instanceof DistanceRankValueComparison)) { + return false; + } + final DistanceRankValueComparison that = (DistanceRankValueComparison)o; + if (!super.equals(o)) { + return false; + } + + return limitValue.equals(that.limitValue); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + limitValue.hashCode(); + return result; + } + + @Override + public String toString() { + return explain().getExplainTokens().render(DefaultExplainFormatter.forDebugging()).toString(); + } + + @Nonnull + @Override + public ExplainTokensWithPrecedence explain() { + return ExplainTokensWithPrecedence.of(new ExplainTokens().addKeyword(getType().name()) + .addWhitespace().addNested(getComparandValue().explain().getExplainTokens()) + .addKeyword(":").addWhitespace() + .addNested(getLimitValue().explain().getExplainTokens())); + } + + @Override + public int computeHashCode() { + return Objects.hash(getType().name(), getComparandValue(), getLimitValue()); + } + + @Override + public int planHash(@Nonnull final PlanHashMode mode) { + switch (mode.getKind()) { + case LEGACY: + case FOR_CONTINUATION: + return PlanHashable.objectsPlanHash(mode, BASE_HASH, getType(), getComparandValue(), getLimitValue()); + default: + throw new UnsupportedOperationException("Hash Kind " + mode.name() + " is not supported"); + } + } + + @Nonnull + @Override + public Comparison withParameterRelationshipMap(@Nonnull final ParameterRelationshipGraph parameterRelationshipGraph) { + Verify.verify(this.parameterRelationshipGraph.isUnbound()); + return new DistanceRankValueComparison(getType(), getComparandValue(), parameterRelationshipGraph, + getLimitValue()); + } + + @Nonnull + @Override + public PDistanceRankValueComparison toProto(@Nonnull final PlanSerializationContext serializationContext) { + return PDistanceRankValueComparison.newBuilder() + .setSuper(super.toValueComparisonProto(serializationContext)) + .setLimitValue(getLimitValue().toValueProto(serializationContext)) + .build(); + } + + @Nonnull + @Override + public PComparison toComparisonProto(@Nonnull final PlanSerializationContext serializationContext) { + return PComparison.newBuilder().setDistanceRankValueComparison(toProto(serializationContext)).build(); + } + + @Nullable + public Vector getVector(@Nullable final FDBRecordStoreBase store, final @Nullable EvaluationContext context) { + final Object comparand = getComparand(store, context); + return comparand == null ? null : Vector.HalfVector.halfVectorFromBytes((byte[])comparand); + } + + public int getLimit(@Nullable final FDBRecordStoreBase store, final @Nullable EvaluationContext context) { + if (context == null) { + throw EvaluationContextRequiredException.instance(); + } + return (int)Objects.requireNonNull(getLimitValue().eval(store, context)); + } + + @Nonnull + public static DistanceRankValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + return new DistanceRankValueComparison(serializationContext, distanceRankValueComparisonProto); + } + + /** + * Deserializer. + */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PDistanceRankValueComparison.class; + } + + @Nonnull + @Override + public DistanceRankValueComparison fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PDistanceRankValueComparison distanceRankValueComparisonProto) { + return DistanceRankValueComparison.fromProto(serializationContext, distanceRankValueComparisonProto); + } + } + } + /** * A comparison with a list of values. */ diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/Type.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/Type.java index e6cc85c390..e72a4ea65e 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/Type.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/Type.java @@ -24,6 +24,7 @@ import com.apple.foundationdb.record.PlanSerializable; import com.apple.foundationdb.record.PlanSerializationContext; import com.apple.foundationdb.record.RecordCoreException; +import com.apple.foundationdb.record.RecordMetaDataOptionsProto; import com.apple.foundationdb.record.TupleFieldsProto; import com.apple.foundationdb.record.logging.LogMessageKeys; import com.apple.foundationdb.record.planprotos.PType; @@ -38,6 +39,7 @@ import com.apple.foundationdb.record.planprotos.PType.PRelationType; import com.apple.foundationdb.record.planprotos.PType.PTypeCode; import com.apple.foundationdb.record.planprotos.PType.PUuidType; +import com.apple.foundationdb.record.planprotos.PType.PVectorType; import com.apple.foundationdb.record.provider.foundationdb.FDBRecordVersion; import com.apple.foundationdb.record.query.plan.cascades.Narrowable; import com.apple.foundationdb.record.query.plan.cascades.NullableArrayTypeUtils; @@ -416,12 +418,18 @@ static List fromTyped(@Nonnull List typedList) { private static Type fromProtoType(@Nullable Descriptors.GenericDescriptor descriptor, @Nonnull Descriptors.FieldDescriptor.Type protoType, @Nonnull FieldDescriptorProto.Label protoLabel, + @Nullable DescriptorProtos.FieldOptions fieldOptions, boolean isNullable) { - final var typeCode = TypeCode.fromProtobufType(protoType); + final var typeCode = TypeCode.fromProtobufFieldDescriptor(protoType, fieldOptions); if (protoLabel == FieldDescriptorProto.Label.LABEL_REPEATED) { // collection type - return fromProtoTypeToArray(descriptor, protoType, typeCode, false); + return fromProtoTypeToArray(descriptor, protoType, typeCode, fieldOptions, false); } else if (typeCode.isPrimitive()) { + final var fieldOptionMaybe = Optional.ofNullable(fieldOptions).map(f -> f.getExtension(RecordMetaDataOptionsProto.field)); + if (fieldOptionMaybe.isPresent() && fieldOptionMaybe.get().hasVectorOptions()) { + final var vectorOptions = fieldOptionMaybe.get().getVectorOptions(); + return Type.Vector.of(isNullable, vectorOptions.getPrecision(), vectorOptions.getDimensions()); + } return primitiveType(typeCode, isNullable); } else if (typeCode == TypeCode.ENUM) { final var enumDescriptor = (Descriptors.EnumDescriptor)Objects.requireNonNull(descriptor); @@ -432,8 +440,10 @@ private static Type fromProtoType(@Nullable Descriptors.GenericDescriptor descri if (NullableArrayTypeUtils.describesWrappedArray(messageDescriptor)) { // find TypeCode of array elements final var elementField = messageDescriptor.findFieldByName(NullableArrayTypeUtils.getRepeatedFieldName()); - final var elementTypeCode = TypeCode.fromProtobufType(elementField.getType()); - return fromProtoTypeToArray(descriptor, protoType, elementTypeCode, true); + final var elementTypeCode = TypeCode.fromProtobufFieldDescriptor(elementField.getType(), elementField.getOptions()); + return fromProtoTypeToArray(descriptor, protoType, elementTypeCode, fieldOptions, true); + } else if (TupleFieldsProto.UUID.getDescriptor().equals(messageDescriptor)) { + return Type.uuidType(isNullable); } else { return Record.fromFieldDescriptorsMap(isNullable, Record.toFieldDescriptorMap(messageDescriptor.getFields())); } @@ -451,7 +461,9 @@ private static Type fromProtoType(@Nullable Descriptors.GenericDescriptor descri @Nonnull private static Array fromProtoTypeToArray(@Nullable Descriptors.GenericDescriptor descriptor, @Nonnull Descriptors.FieldDescriptor.Type protoType, - @Nonnull TypeCode typeCode, boolean isNullable) { + @Nonnull TypeCode typeCode, + @Nullable DescriptorProtos.FieldOptions fieldOptions, + boolean isNullable) { if (typeCode.isPrimitive()) { final var primitiveType = primitiveType(typeCode, false); return new Array(isNullable, primitiveType); @@ -463,10 +475,10 @@ private static Array fromProtoTypeToArray(@Nullable Descriptors.GenericDescripto if (isNullable) { Descriptors.Descriptor wrappedDescriptor = ((Descriptors.Descriptor)Objects.requireNonNull(descriptor)).findFieldByName(NullableArrayTypeUtils.getRepeatedFieldName()).getMessageType(); Objects.requireNonNull(wrappedDescriptor); - return new Array(true, fromProtoType(wrappedDescriptor, Descriptors.FieldDescriptor.Type.MESSAGE, FieldDescriptorProto.Label.LABEL_OPTIONAL, false)); + return new Array(true, fromProtoType(wrappedDescriptor, Descriptors.FieldDescriptor.Type.MESSAGE, FieldDescriptorProto.Label.LABEL_OPTIONAL, fieldOptions, false)); } else { // case 2: any arbitrary sub message we don't understand - return new Array(false, fromProtoType(descriptor, protoType, FieldDescriptorProto.Label.LABEL_OPTIONAL, false)); + return new Array(false, fromProtoType(descriptor, protoType, FieldDescriptorProto.Label.LABEL_OPTIONAL, fieldOptions, false)); } } } @@ -707,6 +719,7 @@ enum TypeCode { INT(Integer.class, FieldDescriptorProto.Type.TYPE_INT32, true, true), LONG(Long.class, FieldDescriptorProto.Type.TYPE_INT64, true, true), STRING(String.class, FieldDescriptorProto.Type.TYPE_STRING, true, false), + VECTOR(Vector.JavaVectorType.class, FieldDescriptorProto.Type.TYPE_BYTES, true, false), VERSION(FDBRecordVersion.class, FieldDescriptorProto.Type.TYPE_BYTES, true, false), ENUM(Enum.class, FieldDescriptorProto.Type.TYPE_ENUM, false, false), RECORD(Message.class, null, false, false), @@ -813,11 +826,12 @@ private static BiMap, TypeCode> computeClassToTypeCodeMap() { /** * Generates a {@link TypeCode} that corresponds to the given protobuf * {@link com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type}. - * @param protobufType The protobuf type. + * @param protobufType The protobuf descriptor of the type. * @return A corresponding {@link TypeCode} instance. */ @Nonnull - public static TypeCode fromProtobufType(@Nonnull final Descriptors.FieldDescriptor.Type protobufType) { + public static TypeCode fromProtobufFieldDescriptor(@Nonnull final Descriptors.FieldDescriptor.Type protobufType, + @Nullable final DescriptorProtos.FieldOptions fieldOptions) { switch (protobufType) { case DOUBLE: return TypeCode.DOUBLE; @@ -845,7 +859,15 @@ public static TypeCode fromProtobufType(@Nonnull final Descriptors.FieldDescript case MESSAGE: return TypeCode.RECORD; case BYTES: + { + if (fieldOptions != null) { + final var recordTypeOptions = fieldOptions.getExtension(RecordMetaDataOptionsProto.field); + if (recordTypeOptions.hasVectorOptions()) { + return TypeCode.VECTOR; + } + } return TypeCode.BYTES; + } default: throw new IllegalArgumentException("unknown protobuf type " + protobufType); } @@ -953,6 +975,7 @@ class Primitive implements Type { @Nonnull private final TypeCode typeCode; + @Nonnull private final Supplier hashCodeSupplier = Suppliers.memoize(this::computeHashCode); private Primitive(final boolean isNullable, @Nonnull final TypeCode typeCode) { @@ -985,12 +1008,12 @@ public void addProtoField(@Nonnull final TypeRepository.Builder typeRepositoryBu @Nonnull final Optional ignored, @Nonnull final FieldDescriptorProto.Label label) { final var protoType = Objects.requireNonNull(getTypeCode().getProtoType()); - descriptorBuilder.addField(FieldDescriptorProto.newBuilder() + final var fieldDescriptorBuilder = FieldDescriptorProto.newBuilder() .setNumber(fieldNumber) .setName(fieldName) .setType(protoType) - .setLabel(label) - .build()); + .setLabel(label); + descriptorBuilder.addField(fieldDescriptorBuilder.build()); } @Override @@ -1174,6 +1197,163 @@ public boolean equals(final Object other) { } } + class Vector implements Type { + private final boolean isNullable; + private final int precision; + private final int dimensions; + + private Vector(final boolean isNullable, final int precision, final int dimensions) { + this.isNullable = isNullable; + this.precision = precision; + this.dimensions = dimensions; + } + + @Nonnull + @SuppressWarnings("PMD.ReplaceVectorWithList") + public static Vector of(final boolean isNullable, final int precision, final int dimensions) { + return new Vector(isNullable, precision, dimensions); + } + + @Override + public TypeCode getTypeCode() { + return TypeCode.VECTOR; + } + + @Override + public boolean isPrimitive() { + return true; + } + + @Override + public boolean isNullable() { + return isNullable; + } + + @Nonnull + @Override + public Type withNullability(final boolean newIsNullable) { + if (isNullable == newIsNullable) { + return this; + } + return new Vector(newIsNullable, precision, dimensions); + } + + public int getPrecision() { + return precision; + } + + public int getDimensions() { + return dimensions; + } + + @Nonnull + @Override + public ExplainTokens describe() { + final var resultExplainTokens = new ExplainTokens(); + resultExplainTokens.addKeyword(getTypeCode().toString()); + return resultExplainTokens.addOptionalWhitespace().addOpeningParen().addOptionalWhitespace() + .addNested(new ExplainTokens().addToString(precision).addToString(", ").addToString(dimensions)).addOptionalWhitespace() + .addClosingParen(); + } + + @Override + public void addProtoField(@Nonnull final TypeRepository.Builder typeRepositoryBuilder, + @Nonnull final DescriptorProto.Builder descriptorBuilder, final int fieldNumber, + @Nonnull final String fieldName, @Nonnull final Optional typeNameOptional, + @Nonnull final FieldDescriptorProto.Label label) { + final var protoType = Objects.requireNonNull(getTypeCode().getProtoType()); + FieldDescriptorProto.Builder builder = FieldDescriptorProto.newBuilder() + .setNumber(fieldNumber) + .setName(fieldName) + .setType(protoType) + .setLabel(label); + final var fieldOptions = RecordMetaDataOptionsProto.FieldOptions.newBuilder() + .setVectorOptions( + RecordMetaDataOptionsProto.FieldOptions.VectorOptions + .newBuilder() + .setPrecision(precision) + .setDimensions(dimensions) + .build()) + .build(); + builder.getOptionsBuilder().setExtension(RecordMetaDataOptionsProto.field, fieldOptions); + typeNameOptional.ifPresent(builder::setTypeName); + descriptorBuilder.addField(builder); + } + + @Nonnull + @Override + public PType toTypeProto(@Nonnull final PlanSerializationContext serializationContext) { + return PType.newBuilder().setVectorType(toProto(serializationContext)).build(); + } + + @Nonnull + @Override + public PVectorType toProto(@Nonnull final PlanSerializationContext serializationContext) { + final PVectorType.Builder vectorTypeBuilder = PVectorType.newBuilder() + .setIsNullable(isNullable) + .setDimensions(dimensions) + .setPrecision(precision); + return vectorTypeBuilder.build(); + } + + @Nonnull + @SuppressWarnings("PMD.ReplaceVectorWithList") + public static Vector fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorType vectorTypeProto) { + Verify.verify(vectorTypeProto.hasIsNullable()); + return new Vector(vectorTypeProto.getIsNullable(), vectorTypeProto.getPrecision(), vectorTypeProto.getDimensions()); + } + + static final class JavaVectorType { + private final ByteString underlying; + + JavaVectorType(@Nonnull final ByteString underlying) { + this.underlying = underlying; + } + + public ByteString getUnderlying() { + return underlying; + } + } + + /** + * Deserializer. + */ + @AutoService(PlanDeserializer.class) + public static class Deserializer implements PlanDeserializer { + @Nonnull + @Override + public Class getProtoMessageClass() { + return PVectorType.class; + } + + @Nonnull + @Override + @SuppressWarnings("PMD.ReplaceVectorWithList") + public Vector fromProto(@Nonnull final PlanSerializationContext serializationContext, + @Nonnull final PVectorType vectorTypeProto) { + return Vector.fromProto(serializationContext, vectorTypeProto); + } + } + + @Override + public int hashCode() { + return Objects.hash(getTypeCode().name(), precision, dimensions); + } + + @Override + @SuppressWarnings("PMD.ReplaceVectorWithList") + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + final Vector vector = (Vector)o; + return isNullable == vector.isNullable + && precision == vector.precision + && dimensions == vector.dimensions; + } + } + /** * The none type is an unresolved type meaning that an entity returning a none type should resolve the * type to a regular type as the runtime does not support a none-typed data producer. Only the empty array constant @@ -2248,10 +2428,12 @@ public static Record fromFieldDescriptorsMap(final boolean isNullable, @Nonnull final var fieldsBuilder = ImmutableList.builder(); for (final var entry : Objects.requireNonNull(fieldDescriptorMap).entrySet()) { final var fieldDescriptor = entry.getValue(); + final var fieldOptions = fieldDescriptor.getOptions(); fieldsBuilder.add( new Field(fromProtoType(getTypeSpecificDescriptor(fieldDescriptor), fieldDescriptor.getType(), fieldDescriptor.toProto().getLabel(), + fieldOptions, !fieldDescriptor.isRequired()), Optional.of(entry.getKey()), Optional.of(fieldDescriptor.getNumber()))); diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/TypeRepository.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/TypeRepository.java index c8aa4d2d2e..30e6de5c3d 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/TypeRepository.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/query/plan/cascades/typing/TypeRepository.java @@ -20,6 +20,7 @@ package com.apple.foundationdb.record.query.plan.cascades.typing; +import com.apple.foundationdb.record.RecordMetaDataOptionsProto; import com.apple.foundationdb.record.TupleFieldsProto; import com.google.common.base.Preconditions; import com.google.common.base.Verify; @@ -67,7 +68,7 @@ public class TypeRepository { public static final TypeRepository EMPTY_SCHEMA = empty(); @Nonnull - public static final List DEPENDENCIES = List.of(TupleFieldsProto.getDescriptor()); + public static final List DEPENDENCIES = List.of(TupleFieldsProto.getDescriptor(), RecordMetaDataOptionsProto.getDescriptor()); @Nonnull private final FileDescriptorSet fileDescSet; diff --git a/fdb-record-layer-core/src/main/proto/record_cursor.proto b/fdb-record-layer-core/src/main/proto/record_cursor.proto index 04095bf779..ddeeff03e7 100644 --- a/fdb-record-layer-core/src/main/proto/record_cursor.proto +++ b/fdb-record-layer-core/src/main/proto/record_cursor.proto @@ -127,6 +127,15 @@ message MultidimensionalIndexScanContinuation { optional bytes lastKey = 2; } +message VectorIndexScanContinuation { + message IndexEntry { + optional bytes key = 1; + optional bytes value = 2; + } + repeated IndexEntry indexEntries = 1; + optional bytes inner_continuation = 2; +} + message TempTableInsertContinuation { optional bytes child_continuation = 1; optional planprotos.PTempTable tempTable = 2; diff --git a/fdb-record-layer-core/src/main/proto/record_metadata_options.proto b/fdb-record-layer-core/src/main/proto/record_metadata_options.proto index 26ecbeb261..34721c09c5 100644 --- a/fdb-record-layer-core/src/main/proto/record_metadata_options.proto +++ b/fdb-record-layer-core/src/main/proto/record_metadata_options.proto @@ -60,8 +60,13 @@ message FieldOptions { repeated Index.Option options = 3; // Note: there is no way to specify these in a .proto file. } optional IndexOption index = 3; + message VectorOptions { + optional int32 precision = 1 [default = 16]; + optional int32 dimensions = 2 [default = 768]; + } + optional VectorOptions vectorOptions = 4; } extend google.protobuf.FieldOptions { - optional FieldOptions field = 1233; + optional FieldOptions field = 1239; } diff --git a/fdb-record-layer-core/src/main/proto/record_query_plan.proto b/fdb-record-layer-core/src/main/proto/record_query_plan.proto index e10c907404..c2fe657786 100644 --- a/fdb-record-layer-core/src/main/proto/record_query_plan.proto +++ b/fdb-record-layer-core/src/main/proto/record_query_plan.proto @@ -52,6 +52,7 @@ message PType { RELATION = 15; NONE = 16; UUID = 17; + VECTOR = 18; } message PPrimitiveType { @@ -75,6 +76,12 @@ message PType { // nothing } + message PVectorType { + optional bool is_nullable = 1; + optional int32 precision = 2; + optional int32 dimensions = 3; + } + message PAnyRecordType { optional bool is_nullable = 1; } @@ -123,6 +130,7 @@ message PType { PArrayType array_type = 8; PAnyRecordType any_record_type = 9; PUuidType uuid_type = 10; + PVectorType vector_type = 11; } } @@ -1189,6 +1197,9 @@ message PComparison { TEXT_CONTAINS_ANY_PREFIX = 17; SORT = 18; LIKE = 19; + DISTANCE_RANK_EQUALS = 20; + DISTANCE_RANK_LESS_THAN = 21; + DISTANCE_RANK_LESS_THAN_OR_EQUAL = 22; } extensions 5000 to max; @@ -1205,6 +1216,7 @@ message PComparison { PRecordTypeComparison record_type_comparison = 10; PConversionSimpleComparison conversion_simple_comparison = 11; PConversionParameterComparison conversion_parameter_comparison = 12; + PDistanceRankValueComparison distance_rank_value_comparison = 13; } } @@ -1271,6 +1283,11 @@ message PRecordTypeComparison { optional string record_type_name = 1; } +message PDistanceRankValueComparison { + optional PValueComparison super = 1; + optional PValue limitValue = 2; +} + // // Query Predicates // @@ -1614,6 +1631,7 @@ message PIndexScanParameters { PIndexScanComparisons index_scan_comparisons = 2; PMultidimensionalIndexScanComparisons multidimensional_index_scan_comparisons = 3; PTimeWindowScanComparisons time_window_scan_comparisons = 4; + PVectorIndexScanComparisons vector_index_scan_comparisons = 5; } } @@ -1649,6 +1667,12 @@ message PTimeWindowScanComparisons { optional PTimeWindowForFunction time_window = 2; } +message PVectorIndexScanComparisons { + optional PScanComparisons prefix_scan_comparisons = 1; + optional PDistanceRankValueComparison distance_rank_value_comparison = 2; + optional PScanComparisons suffix_scan_comparisons = 3; +} + enum PIndexFetchMethod { SCAN_AND_FETCH = 1; USE_REMOTE_FETCH = 2; diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexSimpleTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexSimpleTest.java new file mode 100644 index 0000000000..6124371121 --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexSimpleTest.java @@ -0,0 +1,51 @@ +/* + * MultidimensionalIndexTestBase.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.test.Tags; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tests for multidimensional type indexes. + */ +@Tag(Tags.RequiresFDB) +public class VectorIndexSimpleTest extends VectorIndexTestBase { + private static final Logger logger = LoggerFactory.getLogger(VectorIndexSimpleTest.class); + + @Override + @Test + void basicWriteReadTest() throws Exception { + super.basicWriteReadTest(); + } + + @Test + void basicWriteIndexReadTest() throws Exception { + super.basicWriteIndexReadTest(); + } + + @Test + void basicWriteIndexReadGroupedTest() throws Exception { + super.basicWriteIndexReadGroupedTest(); + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java new file mode 100644 index 0000000000..c03f3691f3 --- /dev/null +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/foundationdb/indexes/VectorIndexTestBase.java @@ -0,0 +1,284 @@ +/* + * MultidimensionalIndexTestBase.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.foundationdb.indexes; + +import com.apple.foundationdb.async.AsyncUtil; +import com.apple.foundationdb.async.hnsw.Metrics; +import com.apple.foundationdb.async.hnsw.Vector; +import com.apple.foundationdb.record.IndexFetchMethod; +import com.apple.foundationdb.record.RecordCursorIterator; +import com.apple.foundationdb.record.RecordMetaData; +import com.apple.foundationdb.record.RecordMetaDataBuilder; +import com.apple.foundationdb.record.metadata.Index; +import com.apple.foundationdb.record.metadata.IndexOptions; +import com.apple.foundationdb.record.metadata.IndexTypes; +import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; +import com.apple.foundationdb.record.provider.foundationdb.FDBQueriedRecord; +import com.apple.foundationdb.record.provider.foundationdb.FDBRecordContext; +import com.apple.foundationdb.record.provider.foundationdb.FDBStoredRecord; +import com.apple.foundationdb.record.provider.foundationdb.VectorIndexScanComparisons; +import com.apple.foundationdb.record.provider.foundationdb.query.FDBRecordStoreQueryTestBase; +import com.apple.foundationdb.record.query.expressions.Comparisons; +import com.apple.foundationdb.record.query.expressions.Comparisons.DistanceRankValueComparison; +import com.apple.foundationdb.record.query.plan.QueryPlanConstraint; +import com.apple.foundationdb.record.query.plan.ScanComparisons; +import com.apple.foundationdb.record.query.plan.cascades.typing.Type; +import com.apple.foundationdb.record.query.plan.cascades.values.LiteralValue; +import com.apple.foundationdb.record.query.plan.plans.RecordQueryFetchFromPartialRecordPlan; +import com.apple.foundationdb.record.query.plan.plans.RecordQueryIndexPlan; +import com.apple.foundationdb.record.vector.TestRecordsVectorsProto; +import com.apple.foundationdb.record.vector.TestRecordsVectorsProto.VectorRecord; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.test.Tags; +import com.christianheina.langx.half4j.Half; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Tag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nonnull; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.function.Function; + +import static com.apple.foundationdb.record.metadata.Key.Expressions.concat; +import static com.apple.foundationdb.record.metadata.Key.Expressions.field; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * Tests for multidimensional type indexes. + */ +@Tag(Tags.RequiresFDB) +public abstract class VectorIndexTestBase extends FDBRecordStoreQueryTestBase { + private static final Logger logger = LoggerFactory.getLogger(VectorIndexTestBase.class); + + private static final SimpleDateFormat timeFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss"); + + @CanIgnoreReturnValue + RecordMetaDataBuilder addVectorIndexes(@Nonnull final RecordMetaDataBuilder metaDataBuilder) { + metaDataBuilder.addIndex("VectorRecord", + new Index("UngroupedVectorIndex", new KeyWithValueExpression(field("vector_data"), 0), + IndexTypes.VECTOR, + ImmutableMap.of(IndexOptions.HNSW_METRIC, Metrics.EUCLIDEAN_METRIC.toString()))); + metaDataBuilder.addIndex("VectorRecord", + new Index("GroupedVectorIndex", new KeyWithValueExpression(concat(field("group_id"), field("vector_data")), 1), + IndexTypes.VECTOR, + ImmutableMap.of(IndexOptions.HNSW_METRIC, Metrics.EUCLIDEAN_METRIC.toString()))); + return metaDataBuilder; + } + + protected void openRecordStore(FDBRecordContext context) throws Exception { + openRecordStore(context, NO_HOOK); + } + + protected void openRecordStore(final FDBRecordContext context, final RecordMetaDataHook hook) throws Exception { + RecordMetaDataBuilder metaDataBuilder = RecordMetaData.newBuilder().setRecords(TestRecordsVectorsProto.getDescriptor()); + metaDataBuilder.getRecordType("VectorRecord").setPrimaryKey(field("rec_no")); + hook.apply(metaDataBuilder); + createOrOpenRecordStore(context, metaDataBuilder.getRecordMetaData()); + } + + static Function getRecordGenerator(@Nonnull final Random random) { + return recNo -> { + final byte[] vector = randomVectorData(random, 128); + random.nextBytes(vector); + + return VectorRecord.newBuilder() + .setRecNo(recNo) + .setVectorData(ByteString.copyFrom(vector)) + .setGroupId(recNo.intValue() % 2) + .build(); + }; + } + + @Nonnull + static byte[] randomVectorData(final Random random, final int dimensions) { + // we do this in this convoluted way to make sure we won't get NaNs and other special surprises + final Half[] componentData = new Half[dimensions]; + for (int i = 0; i < componentData.length; i++) { + componentData[i] = Half.valueOf(random.nextFloat()); + } + + Vector.HalfVector vector = new Vector.HalfVector(componentData); + return vector.getRawData(); + } + + public void saveRecords(final boolean useAsync, @Nonnull final RecordMetaDataHook hook, @Nonnull final Random random, + final int numSamples) { + final var recordGenerator = getRecordGenerator(random); + if (useAsync) { + Assertions.assertDoesNotThrow(() -> batchAsync(hook, numSamples, 100, recNo -> recordStore.saveRecordAsync(recordGenerator.apply(recNo)))); + } else { + Assertions.assertDoesNotThrow(() -> batch(hook, numSamples, 100, recNo -> recordStore.saveRecord(recordGenerator.apply(recNo)))); + } + } + + private long batch(final RecordMetaDataHook hook, final int numRecords, final int batchSize, Consumer recordConsumer) throws Exception { + long numRecordsCommitted = 0; + while (numRecordsCommitted < numRecords) { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, hook); + int recNoInBatch; + + for (recNoInBatch = 0; numRecordsCommitted + recNoInBatch < numRecords && recNoInBatch < batchSize; recNoInBatch++) { + recordConsumer.accept(numRecordsCommitted + recNoInBatch); + } + commit(context); + numRecordsCommitted += recNoInBatch; + logger.info("committed batch, numRecordsCommitted = {}", numRecordsCommitted); + } + } + return numRecordsCommitted; + } + + private long batchAsync(final RecordMetaDataHook hook, final int numRecords, final int batchSize, Function> recordConsumer) throws Exception { + long numRecordsCommitted = 0; + while (numRecordsCommitted < numRecords) { + try (FDBRecordContext context = openContext()) { + openRecordStore(context, hook); + int recNoInBatch; + final var futures = new ArrayList>(); + + for (recNoInBatch = 0; numRecordsCommitted + recNoInBatch < numRecords && recNoInBatch < batchSize; recNoInBatch++) { + futures.add(recordConsumer.apply(numRecordsCommitted + recNoInBatch)); + } + + // wait and then commit + AsyncUtil.whenAll(futures).get(); + commit(context); + numRecordsCommitted += recNoInBatch; + logger.info("committed batch, numRecordsCommitted = {}", numRecordsCommitted); + } + } + return numRecordsCommitted; + } + + private static void logRecord(final long recNo, @Nonnull final ByteString vectorData) { + if (logger.isInfoEnabled()) { + logger.info("recNo: {}; vectorData: [{})", + recNo, Vector.HalfVector.halfVectorFromBytes(vectorData.toByteArray())); + } + } + + void basicWriteReadTest() throws Exception { + final Random random = new Random(); + saveRecords(false, this::addVectorIndexes, random, 1000); + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addVectorIndexes); + for (long l = 0; l < 1000; l ++) { + FDBStoredRecord rec = recordStore.loadRecord(Tuple.from(l)); + + assertNotNull(rec); + VectorRecord.Builder recordBuilder = + VectorRecord.newBuilder(); + recordBuilder.mergeFrom(rec.getRecord()); + final var record = recordBuilder.build(); + logRecord(record.getRecNo(), record.getVectorData()); + } + commit(context); + } + } + + void basicWriteIndexReadTest() throws Exception { + final Random random = new Random(0); + saveRecords(false, this::addVectorIndexes, random, 1000); + + final DistanceRankValueComparison distanceRankComparison = + new DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + new LiteralValue<>(Type.Vector.of(false, 16, 128), + randomVectorData(random, 128)), + new LiteralValue<>(10)); + + final VectorIndexScanComparisons vectorIndexScanComparisons = + new VectorIndexScanComparisons(ScanComparisons.EMPTY, distanceRankComparison, ScanComparisons.EMPTY); + + final var baseRecordType = + Type.Record.fromFieldDescriptorsMap( + Type.Record.toFieldDescriptorMap(VectorRecord.getDescriptor().getFields())); + + final var indexPlan = new RecordQueryIndexPlan("UngroupedVectorIndex", field("recNo"), + vectorIndexScanComparisons, IndexFetchMethod.SCAN_AND_FETCH, + RecordQueryFetchFromPartialRecordPlan.FetchIndexRecords.PRIMARY_KEY, false, false, + Optional.empty(), baseRecordType, QueryPlanConstraint.noConstraint()); + + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addVectorIndexes); + + try (RecordCursorIterator> cursor = executeQuery(indexPlan)) { + while (cursor.hasNext()) { + FDBQueriedRecord rec = cursor.next(); + VectorRecord.Builder myrec = VectorRecord.newBuilder(); + myrec.mergeFrom(Objects.requireNonNull(rec).getRecord()); + System.out.println(myrec); + } + } + + //commit(context); + } + } + + void basicWriteIndexReadGroupedTest() throws Exception { + final Random random = new Random(0); + saveRecords(false, this::addVectorIndexes, random, 1000); + + final DistanceRankValueComparison distanceRankComparison = + new DistanceRankValueComparison(Comparisons.Type.DISTANCE_RANK_LESS_THAN_OR_EQUAL, + new LiteralValue<>(Type.Vector.of(false, 16, 128), + randomVectorData(random, 128)), + new LiteralValue<>(10)); + + final VectorIndexScanComparisons vectorIndexScanComparisons = + new VectorIndexScanComparisons(ScanComparisons.EMPTY, distanceRankComparison, ScanComparisons.EMPTY); + + final var baseRecordType = + Type.Record.fromFieldDescriptorsMap( + Type.Record.toFieldDescriptorMap(VectorRecord.getDescriptor().getFields())); + + final var indexPlan = new RecordQueryIndexPlan("GroupedVectorIndex", field("recNo"), + vectorIndexScanComparisons, IndexFetchMethod.SCAN_AND_FETCH, + RecordQueryFetchFromPartialRecordPlan.FetchIndexRecords.PRIMARY_KEY, false, false, + Optional.empty(), baseRecordType, QueryPlanConstraint.noConstraint()); + + try (FDBRecordContext context = openContext()) { + openRecordStore(context, this::addVectorIndexes); + + try (RecordCursorIterator> cursor = executeQuery(indexPlan)) { + while (cursor.hasNext()) { + FDBQueriedRecord rec = cursor.next(); + VectorRecord.Builder myrec = VectorRecord.newBuilder(); + myrec.mergeFrom(Objects.requireNonNull(rec).getRecord()); + System.out.println(myrec); + } + } + + //commit(context); + } + } +} diff --git a/fdb-record-layer-core/src/test/proto/evolution/test_field_type_change.proto b/fdb-record-layer-core/src/test/proto/evolution/test_field_type_change.proto index eb5805405e..a7194761b2 100644 --- a/fdb-record-layer-core/src/test/proto/evolution/test_field_type_change.proto +++ b/fdb-record-layer-core/src/test/proto/evolution/test_field_type_change.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.evolution.fieldtypechange; option java_package = "com.apple.foundationdb.record.evolution"; option java_outer_classname = "TestFieldTypeChangeProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; import "test_records_1.proto"; // This needs to match the "MySimpleRecord" definition test_records_1.proto diff --git a/fdb-record-layer-core/src/test/proto/evolution/test_header_as_group.proto b/fdb-record-layer-core/src/test/proto/evolution/test_header_as_group.proto index 22478fc812..54ef33d2e7 100644 --- a/fdb-record-layer-core/src/test/proto/evolution/test_header_as_group.proto +++ b/fdb-record-layer-core/src/test/proto/evolution/test_header_as_group.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.evolution.headergroup; option java_package = "com.apple.foundationdb.record.evolution"; option java_outer_classname = "TestHeaderAsGroupProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; // This is taken from test_records_with_header.proto but the header has become a group message MyRecord { diff --git a/fdb-record-layer-core/src/test/proto/evolution/test_swap_union_fields.proto b/fdb-record-layer-core/src/test/proto/evolution/test_swap_union_fields.proto index 3b2a6a29e6..fab35adf8b 100644 --- a/fdb-record-layer-core/src/test/proto/evolution/test_swap_union_fields.proto +++ b/fdb-record-layer-core/src/test/proto/evolution/test_swap_union_fields.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.evolution.swap; option java_package = "com.apple.foundationdb.record.evolution"; option java_outer_classname = "TestSwapUnionFieldsProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; import "test_records_1.proto"; message RecordTypeUnion { diff --git a/fdb-record-layer-core/src/test/proto/expression_tests.proto b/fdb-record-layer-core/src/test/proto/expression_tests.proto index 570a0d920b..4b9f728d69 100644 --- a/fdb-record-layer-core/src/test/proto/expression_tests.proto +++ b/fdb-record-layer-core/src/test/proto/expression_tests.proto @@ -23,7 +23,7 @@ package com.apple.foundationdb.record.metadata; option java_outer_classname = "ExpressionTestsProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; import "tuple_fields.proto"; message TestScalarFieldAccess { diff --git a/fdb-record-layer-core/src/test/proto/test_no_record_types.proto b/fdb-record-layer-core/src/test/proto/test_no_record_types.proto index 92fcd8a13b..e8f0a4c64a 100644 --- a/fdb-record-layer-core/src/test/proto/test_no_record_types.proto +++ b/fdb-record-layer-core/src/test/proto/test_no_record_types.proto @@ -25,7 +25,7 @@ package com.apple.foundationdb.record.testnorecords; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestNoRecordTypesProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; message RecordTypeUnion { } diff --git a/fdb-record-layer-core/src/test/proto/test_records_8.proto b/fdb-record-layer-core/src/test/proto/test_records_8.proto index 81e1d2cf3a..d67700e899 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_8.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_8.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.test8; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecords8Proto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; message StringRecordId { required string rec_id = 1; diff --git a/fdb-record-layer-core/src/test/proto/test_records_chained_2.proto b/fdb-record-layer-core/src/test/proto/test_records_chained_2.proto index c0165dd764..ec6f877a25 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_chained_2.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_chained_2.proto @@ -25,7 +25,7 @@ option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecordsChained2Proto"; import "record_metadata_options.proto"; -import "test_records_1.proto"; +//import "test_records_1.proto"; import "test_records_2.proto"; message MyChainedRecord2 { diff --git a/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields.proto b/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields.proto index 2c5f997caa..b9dec24353 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.test.duplicateunionfields; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecordsDuplicateUnionFields"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; import "test_records_1.proto"; message RecordTypeUnion { diff --git a/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields_reordered.proto b/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields_reordered.proto index b10dbf4d76..f9857dc3e0 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields_reordered.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_duplicate_union_fields_reordered.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.test.duplicateunionfields.reordered; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecordsDuplicateUnionFieldsReordered"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; import "test_records_1.proto"; message RecordTypeUnion { diff --git a/fdb-record-layer-core/src/test/proto/test_records_oneof.proto b/fdb-record-layer-core/src/test/proto/test_records_oneof.proto index 53b6a43c4a..ed25d8ccb0 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_oneof.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_oneof.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.testOneOf; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecordsOneOfProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; message MySimpleRecord { optional int64 rec_no = 1; diff --git a/fdb-record-layer-core/src/test/proto/test_records_transform.proto b/fdb-record-layer-core/src/test/proto/test_records_transform.proto index c20c657542..0c7c7e5732 100644 --- a/fdb-record-layer-core/src/test/proto/test_records_transform.proto +++ b/fdb-record-layer-core/src/test/proto/test_records_transform.proto @@ -24,7 +24,7 @@ package com.apple.foundationdb.record.transform; option java_package = "com.apple.foundationdb.record"; option java_outer_classname = "TestRecordsTransformProto"; -import "record_metadata_options.proto"; +//import "record_metadata_options.proto"; message DefaultTransformMessage { message MessageAa { diff --git a/fdb-record-layer-core/src/test/proto/test_records_vector.proto b/fdb-record-layer-core/src/test/proto/test_records_vector.proto new file mode 100644 index 0000000000..63ff019c9d --- /dev/null +++ b/fdb-record-layer-core/src/test/proto/test_records_vector.proto @@ -0,0 +1,39 @@ +/* + * test_records_multidimensional.proto + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2023 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +syntax = "proto2"; + +package com.apple.foundationdb.record.test.vector; + +option java_package = "com.apple.foundationdb.record.vector"; +option java_outer_classname = "TestRecordsVectorsProto"; + +import "record_metadata_options.proto"; + +option (schema).store_record_versions = true; + +message VectorRecord { + optional int64 rec_no = 1 [(field).primary_key = true]; + optional int32 group_id = 2; + optional bytes vector_data = 3; +} + +message RecordTypeUnion { + optional VectorRecord _VectorRecord = 1; +} diff --git a/fdb-relational-api/src/main/java/com/apple/foundationdb/relational/api/metadata/DataType.java b/fdb-relational-api/src/main/java/com/apple/foundationdb/relational/api/metadata/DataType.java index 56991fb11b..8872a71221 100644 --- a/fdb-relational-api/src/main/java/com/apple/foundationdb/relational/api/metadata/DataType.java +++ b/fdb-relational-api/src/main/java/com/apple/foundationdb/relational/api/metadata/DataType.java @@ -70,6 +70,7 @@ public abstract class DataType { typeCodeJdbcTypeMap.put(Code.ENUM, Types.OTHER); typeCodeJdbcTypeMap.put(Code.UUID, Types.OTHER); typeCodeJdbcTypeMap.put(Code.BYTES, Types.BINARY); + typeCodeJdbcTypeMap.put(Code.VECTOR, Types.OTHER); typeCodeJdbcTypeMap.put(Code.VERSION, Types.BINARY); typeCodeJdbcTypeMap.put(Code.STRUCT, Types.STRUCT); typeCodeJdbcTypeMap.put(Code.ARRAY, Types.ARRAY); @@ -729,6 +730,79 @@ public String toString() { } } + public static final class VectorType extends DataType { + private final int precision; + + private final int dimensions; + + @Nonnull + private final Supplier hashCodeSupplier = Suppliers.memoize(this::computeHashCode); + + private VectorType(final boolean isNullable, int precision, int dimensions) { + super(isNullable, true, Code.VECTOR); + this.precision = precision; + this.dimensions = dimensions; + } + + @Override + public boolean isResolved() { + return true; + } + + @Nonnull + @Override + public DataType withNullable(final boolean isNullable) { + if (isNullable == this.isNullable()) { + return this; + } + return new VectorType(isNullable, precision, dimensions); + } + + @Nonnull + @Override + public DataType resolve(@Nonnull final Map resolutionMap) { + return this; + } + + public int getPrecision() { + return precision; + } + + public int getDimensions() { + return dimensions; + } + + private int computeHashCode() { + return Objects.hash(getCode(), precision, dimensions, isNullable()); + } + + @Override + public int hashCode() { + return hashCodeSupplier.get(); + } + + @Override + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + final VectorType that = (VectorType)o; + return precision == that.precision + && dimensions == that.dimensions + && isNullable() == that.isNullable(); + } + + @Override + public String toString() { + return "vector(p=" + precision + ", d=" + dimensions + ")" + (isNullable() ? " ∪ ∅" : ""); + } + + @Nonnull + public static VectorType of(int precision, int dimensions, boolean isNullable) { + return new VectorType(isNullable, precision, dimensions); + } + } + public static final class VersionType extends DataType { @Nonnull private static final VersionType NOT_NULLABLE_INSTANCE = new VersionType(false); @@ -1504,7 +1578,8 @@ public enum Code { STRUCT, ARRAY, UNKNOWN, - NULL + NULL, + VECTOR, } @SuppressWarnings("PMD.AvoidFieldNameMatchingTypeName") diff --git a/fdb-relational-core/src/main/antlr/RelationalLexer.g4 b/fdb-relational-core/src/main/antlr/RelationalLexer.g4 index b729d6b5b5..43f841233c 100644 --- a/fdb-relational-core/src/main/antlr/RelationalLexer.g4 +++ b/fdb-relational-core/src/main/antlr/RelationalLexer.g4 @@ -112,6 +112,7 @@ GET: 'GET'; GRANT: 'GRANT'; GROUP: 'GROUP'; HAVING: 'HAVING'; +HNSW: 'HNSW'; HIGH_PRIORITY: 'HIGH_PRIORITY'; HISTOGRAM: 'HISTOGRAM'; IF: 'IF'; @@ -160,6 +161,7 @@ OPTION: 'OPTION'; OPTIONAL: 'OPTIONAL'; OPTIONALLY: 'OPTIONALLY'; OR: 'OR'; +ORGANIZED: 'ORGANIZED'; ORDER: 'ORDER'; OUT: 'OUT'; OVER: 'OVER'; @@ -507,6 +509,10 @@ HASH: 'HASH'; HELP: 'HELP'; HOST: 'HOST'; HOSTS: 'HOSTS'; +HNSW_M: 'HNSW_M'; +HNSW_MMAX: 'HNSW_MMAX'; +HNSW_MMAX0: 'HNSW_MMAX0'; +HNSW_EF_CONSTRUCTION: 'HNSW_EF_CONSTRUCTION'; IDENTIFIED: 'IDENTIFIED'; IGNORE_SERVER_IDS: 'IGNORE_SERVER_IDS'; IMPORT: 'IMPORT'; @@ -730,6 +736,13 @@ USER_RESOURCES: 'USER_RESOURCES'; VALIDATION: 'VALIDATION'; VALUE: 'VALUE'; VARIABLES: 'VARIABLES'; +VECTOR: 'VECTOR'; +VECTOR16: 'VECTOR16'; +VECTOR32: 'VECTOR32'; +VECTOR64: 'VECTOR64'; +FLOATVECTOR: 'FLOATVECTOR'; +DOUBLEVECTOR: 'DOUBLEVECTOR'; +HALFVECTOR: 'HALFVECTOR'; VIEW: 'VIEW'; VIRTUAL: 'VIRTUAL'; VISIBLE: 'VISIBLE'; diff --git a/fdb-relational-core/src/main/antlr/RelationalParser.g4 b/fdb-relational-core/src/main/antlr/RelationalParser.g4 index 139de5984c..39ad26350e 100644 --- a/fdb-relational-core/src/main/antlr/RelationalParser.g4 +++ b/fdb-relational-core/src/main/antlr/RelationalParser.g4 @@ -122,9 +122,25 @@ structDefinition ; tableDefinition - : TABLE uid LEFT_ROUND_BRACKET columnDefinition (COMMA columnDefinition)* COMMA primaryKeyDefinition RIGHT_ROUND_BRACKET + : TABLE uid LEFT_ROUND_BRACKET columnDefinition (COMMA columnDefinition)* COMMA primaryKeyDefinition (COMMA organizedByClause)? RIGHT_ROUND_BRACKET ; +organizedByClause + : ORGANIZED BY HNSW '(' embeddingsCol=fullId partitionClause? ')' hnswConfigurations? + ; + +hnswConfigurations + : WITH '(' hnswConfiguration (COMMA hnswConfiguration)* ')' + ; + +hnswConfiguration + : HNSW_M '=' mValue=DECIMAL_LITERAL + | HNSW_MMAX '=' mMaxValue=DECIMAL_LITERAL + | HNSW_MMAX0 '=' mMax0Value=DECIMAL_LITERAL + | HNSW_EF_CONSTRUCTION '=' efConstructionValue=DECIMAL_LITERAL + ; + + columnDefinition : colName=uid columnType ARRAY? columnConstraint? ; @@ -138,7 +154,17 @@ columnType : primitiveType | customType=uid; primitiveType - : BOOLEAN | INTEGER | BIGINT | FLOAT | DOUBLE | STRING | BYTES; + : BOOLEAN | INTEGER | BIGINT | FLOAT | DOUBLE | STRING | BYTES | vectorType; + +vectorType + : VECTOR '(' length=DECIMAL_LITERAL ')' + | VECTOR16 '(' length=DECIMAL_LITERAL ')' + | VECTOR32 '(' length=DECIMAL_LITERAL ')' + | VECTOR64 '(' length=DECIMAL_LITERAL ')' + | HALFVECTOR '(' length=DECIMAL_LITERAL ')' + | FLOATVECTOR '(' length=DECIMAL_LITERAL ')' + | DOUBLEVECTOR '(' length=DECIMAL_LITERAL ')' + ; columnConstraint : nullNotnull #nullColumnConstraint @@ -1100,10 +1126,11 @@ frameRange | expression (PRECEDING | FOLLOWING) ; +*/ + partitionClause : PARTITION BY expression (',' expression)* ; -*/ scalarFunctionName : functionNameBase diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/catalog/RecordLayerStoreSchemaTemplateCatalog.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/catalog/RecordLayerStoreSchemaTemplateCatalog.java index cea1c2128b..c755308d7e 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/catalog/RecordLayerStoreSchemaTemplateCatalog.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/catalog/RecordLayerStoreSchemaTemplateCatalog.java @@ -76,6 +76,14 @@ */ class RecordLayerStoreSchemaTemplateCatalog implements SchemaTemplateCatalog { + @Nonnull + private static final com.google.protobuf.ExtensionRegistry registry = com.google.protobuf.ExtensionRegistry.newInstance(); + + static { + registry.add(com.apple.foundationdb.record.RecordMetaDataOptionsProto.field); + registry.add(com.apple.foundationdb.record.RecordMetaDataOptionsProto.record); + } + @Nonnull private final RecordLayerSchema catalogSchema; @@ -210,7 +218,7 @@ private static SchemaTemplate toSchemaTemplate(@Nonnull final Message message) t // deserialization of the same message over and over again. final Descriptors.Descriptor descriptor = message.getDescriptorForType(); final ByteString bs = Assert.castUnchecked(message.getField(descriptor.findFieldByName(SchemaTemplateSystemTable.METADATA)), ByteString.class); - final RecordMetaData metaData = RecordMetaData.build(RecordMetaDataProto.MetaData.parseFrom(bs.toByteArray())); + final RecordMetaData metaData = RecordMetaData.newBuilder().setRecords(RecordMetaDataProto.MetaData.parseFrom(bs.toByteArray(), registry)).getRecordMetaData(); final String name = message.getField(descriptor.findFieldByName(SchemaTemplateSystemTable.TEMPLATE_NAME)).toString(); int templateVersion = (int) message.getField(descriptor.findFieldByName(SchemaTemplateSystemTable.TEMPLATE_VERSION)); return RecordLayerSchemaTemplate.fromRecordMetadata(metaData, name, templateVersion); diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/metadata/DataTypeUtils.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/metadata/DataTypeUtils.java index 2c3820be2c..d58ad4ed0c 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/metadata/DataTypeUtils.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/metadata/DataTypeUtils.java @@ -61,6 +61,12 @@ public static DataType toRelationalType(@Nonnull final Type type) { final var typeCode = type.getTypeCode(); + // if the type code is BYTES, + if (typeCode.equals(Type.TypeCode.VECTOR)) { + final var vectorType = (Type.Vector)type; + return DataType.VectorType.of(vectorType.getPrecision(), vectorType.getDimensions(), vectorType.isNullable()); + } + if (typeCode == Type.TypeCode.ANY || typeCode == Type.TypeCode.NONE || typeCode == Type.TypeCode.NULL || typeCode == Type.TypeCode.UNKNOWN) { return DataType.UnknownType.instance(); } @@ -112,10 +118,19 @@ public static Type toRecordLayerType(@Nonnull final DataType type) { return primitivesMap.get(type); } + // handle bytes with fixed size and precision. + if (type.getCode().equals(DataType.Code.VECTOR)) { + final var vectorType = (DataType.VectorType)type; + final var precision = vectorType.getPrecision(); + final var dimensions = vectorType.getDimensions(); + return Type.Vector.of(type.isNullable(), precision, dimensions); + } + switch (type.getCode()) { case STRUCT: final var struct = (DataType.StructType) type; - final var fields = struct.getFields().stream().map(field -> Type.Record.Field.of(DataTypeUtils.toRecordLayerType(field.getType()), Optional.of(field.getName()), Optional.of(field.getIndex()))).collect(Collectors.toList()); + final var fields = struct.getFields().stream().map(field -> Type.Record.Field.of(DataTypeUtils.toRecordLayerType(field.getType()), + Optional.of(field.getName()), Optional.of(field.getIndex()))).collect(Collectors.toList()); return Type.Record.fromFieldsWithName(struct.getName(), struct.isNullable(), fields); case ARRAY: final var asArray = (DataType.ArrayType) type; @@ -123,11 +138,13 @@ public static Type toRecordLayerType(@Nonnull final DataType type) { // but since in RL we store the elements as a 'repeated' field, there is not a way to tell if an element is explicitly 'null'. // The current RL behavior loses the nullability information even if the constituent of Type.Array is explicitly marked 'nullable'. Hence, // the check here avoids silently swallowing the requirement. - Assert.thatUnchecked(asArray.getElementType().getCode() == DataType.Code.NULL || !asArray.getElementType().isNullable(), ErrorCode.UNSUPPORTED_OPERATION, "No support for nullable array elements."); + Assert.thatUnchecked(asArray.getElementType().getCode() == DataType.Code.NULL || !asArray.getElementType().isNullable(), + ErrorCode.UNSUPPORTED_OPERATION, "No support for nullable array elements."); return new Type.Array(asArray.isNullable(), toRecordLayerType(asArray.getElementType())); case ENUM: final var asEnum = (DataType.EnumType) type; - final List enumValues = asEnum.getValues().stream().map(v -> new Type.Enum.EnumValue(v.getName(), v.getNumber())).collect(Collectors.toList()); + final List enumValues = asEnum.getValues().stream().map(v -> new Type.Enum.EnumValue(v.getName(), + v.getNumber())).collect(Collectors.toList()); return new Type.Enum(asEnum.isNullable(), enumValues, asEnum.getName()); case UNKNOWN: return new Type.Any(); diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/IndexGenerator.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/IndexGenerator.java index 14c10f53c5..ad5156d75a 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/IndexGenerator.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/IndexGenerator.java @@ -75,6 +75,7 @@ import com.apple.foundationdb.relational.util.NullableArrayUtils; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; @@ -484,7 +485,7 @@ private KeyExpression removeBitmapBucketOffset(@Nonnull KeyExpression groupingEx } @Nonnull - private KeyExpression generate(@Nonnull List fields, @Nonnull Map orderingFunctions) { + private static KeyExpression generate(@Nonnull List fields, @Nonnull Map orderingFunctions) { if (fields.isEmpty()) { return EmptyKeyExpression.EMPTY; } else if (fields.size() == 1) { @@ -514,7 +515,7 @@ private KeyExpression generate(@Nonnull List fields, @Nonnull Map orderingFunctions) { + private static KeyExpression toKeyExpression(Value value, Map orderingFunctions) { var expr = toKeyExpression(value); if (orderingFunctions.containsKey(value)) { return function(orderingFunctions.get(value), expr); @@ -524,7 +525,7 @@ private KeyExpression toKeyExpression(Value value, Map orderingFu } @Nonnull - private KeyExpression toKeyExpression(@Nonnull Value value) { + private static KeyExpression toKeyExpression(@Nonnull Value value) { if (value instanceof VersionValue) { return VersionKeyExpression.VERSION; } else if (value instanceof FieldValue) { @@ -739,12 +740,12 @@ private Value dereference(@Nonnull Value value) { } @Nonnull - private KeyExpression toKeyExpression(@Nonnull List> fields) { + private static KeyExpression toKeyExpression(@Nonnull List> fields) { return toKeyExpression(fields, 0); } @Nonnull - private KeyExpression toKeyExpression(@Nonnull List> fields, int index) { + private static KeyExpression toKeyExpression(@Nonnull List> fields, int index) { Assert.thatUnchecked(!fields.isEmpty()); final var field = fields.get(index); final var keyExpression = toKeyExpression(field.getLeft(), field.getRight()); @@ -779,4 +780,43 @@ private static FieldKeyExpression toKeyExpression(@Nonnull String name, @Nonnull public static IndexGenerator from(@Nonnull RelationalExpression relationalExpression, boolean useLongBasedExtremumEver) { return new IndexGenerator(relationalExpression, useLongBasedExtremumEver); } + + @Nonnull + public static RecordLayerIndex generateHnswIndex(@Nonnull final String tableName, + @Nonnull final Expression embedding, + @Nonnull final Expressions partitionExpressions, + @Nonnull final Map options) { + final var embeddingKeyExpression = toKeyExpression(embedding.getUnderlying()); + final var partitionKeyExpression = generate(ImmutableList.copyOf(partitionExpressions.underlying().iterator()), + Collections.emptyMap()); + final var keyExpression = keyWithValue(concat(partitionKeyExpression, embeddingKeyExpression), + partitionKeyExpression.getColumnSize()); + + final var indexOptions = options.entrySet().stream().map( entry -> { + final var key = entry.getKey(); + final var value = entry.getValue(); + switch (key) { + case "HNSW_M": + return Map.entry(IndexOptions.HNSW_M, value); + case "HNSW_EF_CONSTRUCTION": + return Map.entry(IndexOptions.HNSW_EF_CONSTRUCTION, value); + case "HNSW_MMAX": + return Map.entry(IndexOptions.HNSW_M_MAX, value); + case "HNSW_MMAX0": + return Map.entry(IndexOptions.HNSW_M_MAX_0, value); + default: + throw new RelationalException("unknown HNSW option: " + key, ErrorCode.SYNTAX_ERROR) + .toUncheckedWrappedException(); + } + }).collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + + + final var builder = RecordLayerIndex.newBuilder(); + return builder.setIndexType(IndexTypes.VECTOR) + .setName(tableName + "$hnsw") + .setOptions(indexOptions) + .setTableName(tableName) + .setKeyExpression(keyExpression) + .build(); + } } diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/SemanticAnalyzer.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/SemanticAnalyzer.java index 648bbba39d..d391223ef3 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/SemanticAnalyzer.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/SemanticAnalyzer.java @@ -502,6 +502,130 @@ public Optional lookupNestedField(@Nonnull Identifier requestedIdent return Optional.of(nestedAttribute); } + public static final class ParsedTypeInfo { + @Nullable + private final RelationalParser.PrimitiveTypeContext primitiveTypeContext; + + @Nullable + private final Identifier customType; + + private final boolean isNullable; + + private final boolean isRepeated; + + private ParsedTypeInfo(@Nullable final RelationalParser.PrimitiveTypeContext primitiveTypeContext, + @Nullable final Identifier customType, final boolean isNullable, final boolean isRepeated) { + this.primitiveTypeContext = primitiveTypeContext; + this.customType = customType; + this.isNullable = isNullable; + this.isRepeated = isRepeated; + } + + public boolean hasPrimitiveType() { + return primitiveTypeContext != null; + } + + @Nullable + public RelationalParser.PrimitiveTypeContext getPrimitiveTypeContext() { + return primitiveTypeContext; + } + + public boolean hasCustomType() { + return customType != null; + } + + @Nullable + public Identifier getCustomType() { + return customType; + } + + public boolean isNullable() { + return isNullable; + } + + public boolean isRepeated() { + return isRepeated; + } + + @Nonnull + public static ParsedTypeInfo ofPrimitiveType(@Nonnull final RelationalParser.PrimitiveTypeContext primitiveTypeContext, + final boolean isNullable, final boolean isRepeated) { + return new ParsedTypeInfo(primitiveTypeContext, null, isNullable, isRepeated); + } + + @Nonnull + public static ParsedTypeInfo ofCustomType(@Nonnull final Identifier customType, + final boolean isNullable, final boolean isRepeated) { + return new ParsedTypeInfo(null, customType, isNullable, isRepeated); + } + } + + @Nonnull + public DataType lookupType(@Nonnull final ParsedTypeInfo parsedTypeInfo, + @Nonnull final Function> dataTypeProvider) { + DataType type; + final var isNullable = parsedTypeInfo.isNullable(); + if (parsedTypeInfo.hasCustomType()) { + final var typeName = Assert.notNullUnchecked(parsedTypeInfo.getCustomType()).getName(); + final var maybeFound = dataTypeProvider.apply(typeName); + // if we cannot find the type now, mark it, we will try to resolve it later on via a second pass. + type = maybeFound.orElseGet(() -> DataType.UnresolvedType.of(typeName, isNullable)); + } else { + final var primitiveType = Assert.notNullUnchecked(parsedTypeInfo.getPrimitiveTypeContext()); + if (primitiveType.vectorType() != null) { + final var ctx = primitiveType.vectorType(); + int precision = 16; + if (ctx.VECTOR32() != null || ctx.FLOATVECTOR() != null) { + precision = 32; + } else if (ctx.VECTOR64() != null || ctx.DOUBLEVECTOR() != null) { + precision = 64; + } + int length = Assert.castUnchecked(ParseHelpers.parseDecimal(ctx.length.getText()), Integer.class); + type = DataType.VectorType.of(precision, length, isNullable); + } else { + final var primitiveTypeName = parsedTypeInfo.getPrimitiveTypeContext().getText(); + + switch (primitiveTypeName.toUpperCase(Locale.ROOT)) { + case "STRING": + type = isNullable ? DataType.Primitives.NULLABLE_STRING.type() : DataType.Primitives.STRING.type(); + break; + case "INTEGER": + type = isNullable ? DataType.Primitives.NULLABLE_INTEGER.type() : DataType.Primitives.INTEGER.type(); + break; + case "BIGINT": + type = isNullable ? DataType.Primitives.NULLABLE_LONG.type() : DataType.Primitives.LONG.type(); + break; + case "DOUBLE": + type = isNullable ? DataType.Primitives.NULLABLE_DOUBLE.type() : DataType.Primitives.DOUBLE.type(); + break; + case "BOOLEAN": + type = isNullable ? DataType.Primitives.NULLABLE_BOOLEAN.type() : DataType.Primitives.BOOLEAN.type(); + break; + case "BYTES": + type = isNullable ? DataType.Primitives.NULLABLE_BYTES.type() : DataType.Primitives.BYTES.type(); + break; + case "FLOAT": + type = isNullable ? DataType.Primitives.NULLABLE_FLOAT.type() : DataType.Primitives.FLOAT.type(); + break; + default: + Assert.notNullUnchecked(metadataCatalog); + // assume it is a custom type, will fail in upper layers if the type can not be resolved. + // lookup the type (Struct, Table, or Enum) in the schema template metadata under construction. + final var maybeFound = dataTypeProvider.apply(primitiveTypeName); + // if we cannot find the type now, mark it, we will try to resolve it later on via a second pass. + type = maybeFound.orElseGet(() -> DataType.UnresolvedType.of(primitiveTypeName, isNullable)); + break; + } + } + } + + if (parsedTypeInfo.isRepeated()) { + return DataType.ArrayType.from(type.withNullable(false), isNullable); + } else { + return type; + } + } + @Nonnull public DataType lookupType(@Nonnull Identifier typeIdentifier, boolean isNullable, boolean isRepeated, @Nonnull Function> dataTypeProvider) { diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/BaseVisitor.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/BaseVisitor.java index 394c91ec0b..bb2772ee02 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/BaseVisitor.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/BaseVisitor.java @@ -55,6 +55,7 @@ import javax.annotation.Nullable; import java.net.URI; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -352,6 +353,23 @@ public RecordLayerTable visitTableDefinition(@Nonnull RelationalParser.TableDefi return ddlVisitor.visitTableDefinition(ctx); } + @Override + public Object visitOrganizedByClause(final RelationalParser.OrganizedByClauseContext ctx) { + return ddlVisitor.visitOrganizedByClause(ctx); + } + + @Nonnull + @Override + public Map visitHnswConfigurations(final RelationalParser.HnswConfigurationsContext ctx) { + return ddlVisitor.visitHnswConfigurations(ctx); + } + + @Nonnull + @Override + public NonnullPair visitHnswConfiguration(final RelationalParser.HnswConfigurationContext ctx) { + return ddlVisitor.visitHnswConfiguration(ctx); + } + @Nonnull @Override public Object visitColumnDefinition(@Nonnull RelationalParser.ColumnDefinitionContext ctx) { @@ -370,10 +388,14 @@ public DataType visitColumnType(@Nonnull RelationalParser.ColumnTypeContext ctx) return ddlVisitor.visitColumnType(ctx); } - @Nonnull @Override - public DataType visitPrimitiveType(@Nonnull RelationalParser.PrimitiveTypeContext ctx) { - return ddlVisitor.visitPrimitiveType(ctx); + public Object visitPrimitiveType(final RelationalParser.PrimitiveTypeContext ctx) { + return visitChildren(ctx); + } + + @Override + public Object visitVectorType(final RelationalParser.VectorTypeContext ctx) { + return visitChildren(ctx); } @Nonnull @@ -1502,6 +1524,12 @@ public Object visitWindowName(@Nonnull RelationalParser.WindowNameContext ctx) { return visitChildren(ctx); } + @Nonnull + @Override + public Expressions visitPartitionClause(final RelationalParser.PartitionClauseContext ctx) { + return ddlVisitor.visitPartitionClause(ctx); + } + @Nonnull @Override public Object visitScalarFunctionName(@Nonnull RelationalParser.ScalarFunctionNameContext ctx) { diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DdlVisitor.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DdlVisitor.java index 5575626681..07476789c6 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DdlVisitor.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DdlVisitor.java @@ -24,9 +24,11 @@ import com.apple.foundationdb.record.query.plan.cascades.expressions.LogicalSortExpression; import com.apple.foundationdb.record.query.plan.cascades.values.PromoteValue; import com.apple.foundationdb.record.query.plan.cascades.values.ThrowsValue; +import com.apple.foundationdb.record.util.pair.NonnullPair; import com.apple.foundationdb.relational.api.Options; import com.apple.foundationdb.relational.api.ddl.MetadataOperationsFactory; import com.apple.foundationdb.relational.api.exceptions.ErrorCode; +import com.apple.foundationdb.relational.api.exceptions.RelationalException; import com.apple.foundationdb.relational.api.metadata.DataType; import com.apple.foundationdb.relational.api.metadata.InvokedRoutine; import com.apple.foundationdb.relational.generated.RelationalParser; @@ -41,6 +43,7 @@ import com.apple.foundationdb.relational.recordlayer.query.Identifier; import com.apple.foundationdb.relational.recordlayer.query.IndexGenerator; import com.apple.foundationdb.relational.recordlayer.query.LogicalOperator; +import com.apple.foundationdb.relational.recordlayer.query.LogicalOperators; import com.apple.foundationdb.relational.recordlayer.query.PreparedParams; import com.apple.foundationdb.relational.recordlayer.query.ProceduralPlan; import com.apple.foundationdb.relational.recordlayer.query.QueryParser; @@ -48,15 +51,17 @@ import com.apple.foundationdb.relational.recordlayer.query.functions.CompiledSqlFunction; import com.apple.foundationdb.relational.util.Assert; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.antlr.v4.runtime.ParserRuleContext; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.net.URI; import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.Optional; +import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; @@ -99,11 +104,14 @@ public static DdlVisitor of(@Nonnull BaseVisitor delegate, @Override public DataType visitFunctionColumnType(@Nonnull final RelationalParser.FunctionColumnTypeContext ctx) { final var semanticAnalyzer = getDelegate().getSemanticAnalyzer(); + final SemanticAnalyzer.ParsedTypeInfo typeInfo; if (ctx.customType != null) { final var columnType = visitUid(ctx.customType); - return semanticAnalyzer.lookupType(columnType, true, false, metadataBuilder::findType); + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofCustomType(columnType, true, false); + } else { + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofPrimitiveType(ctx.primitiveType(), true, false); } - return visitPrimitiveType(ctx.primitiveType()).withNullable(true); + return semanticAnalyzer.lookupType(typeInfo, metadataBuilder::findType); } // TODO: remove @@ -111,20 +119,14 @@ public DataType visitFunctionColumnType(@Nonnull final RelationalParser.Function @Override public DataType visitColumnType(@Nonnull RelationalParser.ColumnTypeContext ctx) { final var semanticAnalyzer = getDelegate().getSemanticAnalyzer(); + final SemanticAnalyzer.ParsedTypeInfo typeInfo; if (ctx.customType != null) { final var columnType = visitUid(ctx.customType); - return semanticAnalyzer.lookupType(columnType, false, false, metadataBuilder::findType); + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofCustomType(columnType, false, false); + } else { + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofPrimitiveType(ctx.primitiveType(), false, false); } - return visitPrimitiveType(ctx.primitiveType()); - } - - // TODO: remove - @Nonnull - @Override - public DataType visitPrimitiveType(@Nonnull RelationalParser.PrimitiveTypeContext ctx) { - final var semanticAnalyzer = getDelegate().getSemanticAnalyzer(); - final var primitiveType = Identifier.of(ctx.getText()); - return semanticAnalyzer.lookupType(primitiveType, false, false, ignored -> Optional.empty()); + return semanticAnalyzer.lookupType(typeInfo, metadataBuilder::findType); } /** @@ -147,9 +149,16 @@ public RecordLayerColumn visitColumnDefinition(@Nonnull RelationalParser.ColumnD // but a way to represent it in RecordMetadata. Assert.thatUnchecked(isRepeated || isNullable, ErrorCode.UNSUPPORTED_OPERATION, "NOT NULL is only allowed for ARRAY column type"); containsNullableArray = containsNullableArray || (isRepeated && isNullable); - final var columnTypeId = ctx.columnType().customType != null ? visitUid(ctx.columnType().customType) : Identifier.of(ctx.columnType().getText()); + final var semanticAnalyzer = getDelegate().getSemanticAnalyzer(); - final var columnType = semanticAnalyzer.lookupType(columnTypeId, isNullable, isRepeated, metadataBuilder::findType); + final SemanticAnalyzer.ParsedTypeInfo typeInfo; + if (ctx.columnType().customType != null) { + final var columnType = visitUid(ctx.columnType().customType); + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofCustomType(columnType, true, isRepeated); + } else { + typeInfo = SemanticAnalyzer.ParsedTypeInfo.ofPrimitiveType(ctx.columnType().primitiveType(), true, isRepeated); + } + final var columnType = semanticAnalyzer.lookupType(typeInfo, metadataBuilder::findType); return RecordLayerColumn.newBuilder().setName(columnId.getName()).setDataType(columnType).build(); } @@ -170,6 +179,74 @@ public RecordLayerTable visitTableDefinition(@Nonnull RelationalParser.TableDefi return tableBuilder.build(); } + public RecordLayerIndex visitIntrinsicIndex(@Nonnull final RelationalParser.TableDefinitionContext ctx) { + return getDelegate().getPlanGenerationContext().withDisabledLiteralProcessing(() -> { + Assert.thatUnchecked(ctx.organizedByClause() != null); + final var ddlCatalog = metadataBuilder.build(); + // parse the index SQL query using the newly constructed metadata. + getDelegate().replaceSchemaTemplate(ddlCatalog); + + // create a synthetic plan fragment comprising the table access only. This is important for resolving the + // components of the clause correctly, including the embedding column and the partitioning columns. + final var tableName = visitUid(ctx.uid()); + final var logicalOperator = LogicalOperator.generateTableAccess(tableName, ImmutableSet.of(), getDelegate().getSemanticAnalyzer()); + + final var organizedByClause = ctx.organizedByClause(); + + final var embeddingColumnId = visitFullId(organizedByClause.embeddingsCol); + final var embeddingColumn = getDelegate().getSemanticAnalyzer().resolveIdentifier(embeddingColumnId, LogicalOperators.ofSingle(logicalOperator)); + + getDelegate().pushPlanFragment().setOperator(logicalOperator); + final var partitionExpressions = (organizedByClause.partitionClause() == null) + ? Expressions.empty() + : visitPartitionClause(organizedByClause.partitionClause()); + getDelegate().popPlanFragment(); + final var indexOptions = (organizedByClause.hnswConfigurations() == null) + ? ImmutableMap.of() + : visitHnswConfigurations(organizedByClause.hnswConfigurations()); + + return IndexGenerator.generateHnswIndex(tableName.getName(), embeddingColumn, partitionExpressions, indexOptions); + }); + } + + @Nullable + @Override + public Object visitOrganizedByClause(final RelationalParser.OrganizedByClauseContext ctx) { + return null; // postpone processing, it should start exactly after the table is fully resolved. + } + + @Nonnull + @Override + public Map visitHnswConfigurations(final RelationalParser.HnswConfigurationsContext ctx) { + return ctx.hnswConfiguration().stream().map(this::visitHnswConfiguration) + .collect(ImmutableMap.toImmutableMap( + NonnullPair::getLeft, + NonnullPair::getRight, + (existing, replacement) -> { + throw new RelationalException("duplicate configuration '" + existing + "'", ErrorCode.SYNTAX_ERROR) + .toUncheckedWrappedException(); + })); + } + + @Nonnull + @Override + public NonnullPair visitHnswConfiguration(final RelationalParser.HnswConfigurationContext ctx) { + if (ctx.mValue != null) { + return NonnullPair.of("HNSW_M", ctx.mValue.getText()); + } + if (ctx.efConstructionValue != null) { + return NonnullPair.of("HNSW_EF_CONSTRUCTION", ctx.efConstructionValue.getText()); + } + if (ctx.mMaxValue != null) { + return NonnullPair.of("HNSW_MMAX", ctx.mMaxValue.getText()); + } + if (ctx.mMax0Value != null) { + return NonnullPair.of("HNSW_MMAX0", ctx.mMax0Value.getText()); + } + Assert.failUnchecked(ErrorCode.SYNTAX_ERROR, "unknown hnsw configuration" + ctx); + return null; + } + @Nonnull @Override public RecordLayerTable visitStructDefinition(@Nonnull RelationalParser.StructDefinitionContext ctx) { @@ -251,9 +328,16 @@ public ProceduralPlan visitCreateSchemaTemplateStatement(@Nonnull RelationalPars indexClauses.add(templateClause.indexDefinition()); } } + final var indexes = ImmutableList.builder(); structClauses.build().stream().map(this::visitStructDefinition).map(RecordLayerTable::getDatatype).forEach(metadataBuilder::addAuxiliaryType); - tableClauses.build().stream().map(this::visitTableDefinition).forEach(metadataBuilder::addTable); - final var indexes = indexClauses.build().stream().map(this::visitIndexDefinition).collect(ImmutableList.toImmutableList()); + for (final var tableClause : tableClauses.build()) { + metadataBuilder.addTable(visitTableDefinition(tableClause)); + if (tableClause.organizedByClause() != null) { + indexes.add(visitIntrinsicIndex(tableClause)); + } + } + + indexClauses.build().stream().map(this::visitIndexDefinition).forEach(indexes::add); // TODO: this is currently relying on the lexical order of the function to resolve function dependencies which // is limited. functionClauses.build().forEach(functionClause -> { @@ -261,7 +345,7 @@ public ProceduralPlan visitCreateSchemaTemplateStatement(@Nonnull RelationalPars functionClause.routineBody(), metadataBuilder.build()); metadataBuilder.addInvokedRoutine(invokedRoutine); }); - for (final var index : indexes) { + for (final var index : indexes.build()) { final var table = metadataBuilder.extractTable(index.getTableName()); final var tableWithIndex = RecordLayerTable.Builder.from(table).addIndex(index).build(); metadataBuilder.addTable(tableWithIndex); @@ -505,4 +589,11 @@ public DataType visitReturnsType(@Nonnull RelationalParser.ReturnsTypeContext ct public Boolean visitNullColumnConstraint(@Nonnull RelationalParser.NullColumnConstraintContext ctx) { return ctx.nullNotnull().NOT() == null; } + + @Nonnull + @Override + public Expressions visitPartitionClause(final RelationalParser.PartitionClauseContext ctx) { + return Expressions.of(ctx.expression().stream().map(expContext -> + Assert.castUnchecked(visit(expContext), Expression.class)).collect(ImmutableList.toImmutableList())); + } } diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DelegatingVisitor.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DelegatingVisitor.java index 4f2663f3b9..f3bb965f90 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DelegatingVisitor.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/DelegatingVisitor.java @@ -45,6 +45,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.List; +import java.util.Map; import java.util.Set; @API(API.Status.EXPERIMENTAL) @@ -182,6 +183,24 @@ public RecordLayerTable visitTableDefinition(@Nonnull RelationalParser.TableDefi return getDelegate().visitTableDefinition(ctx); } + @Nullable + @Override + public Object visitOrganizedByClause(final RelationalParser.OrganizedByClauseContext ctx) { + return getDelegate().visitOrganizedByClause(ctx); + } + + @Nonnull + @Override + public Map visitHnswConfigurations(final RelationalParser.HnswConfigurationsContext ctx) { + return getDelegate().visitHnswConfigurations(ctx); + } + + @Nonnull + @Override + public NonnullPair visitHnswConfiguration(final RelationalParser.HnswConfigurationContext ctx) { + return getDelegate().visitHnswConfiguration(ctx); + } + @Nonnull @Override public Object visitColumnDefinition(@Nonnull RelationalParser.ColumnDefinitionContext ctx) { @@ -200,12 +219,16 @@ public DataType visitColumnType(@Nonnull RelationalParser.ColumnTypeContext ctx) return getDelegate().visitColumnType(ctx); } - @Nonnull @Override - public DataType visitPrimitiveType(@Nonnull RelationalParser.PrimitiveTypeContext ctx) { + public Object visitPrimitiveType(final RelationalParser.PrimitiveTypeContext ctx) { return getDelegate().visitPrimitiveType(ctx); } + @Override + public Object visitVectorType(final RelationalParser.VectorTypeContext ctx) { + return getDelegate().visitVectorType(ctx); + } + @Nonnull @Override public Boolean visitNullColumnConstraint(@Nonnull RelationalParser.NullColumnConstraintContext ctx) { @@ -1341,6 +1364,12 @@ public Object visitWindowName(@Nonnull RelationalParser.WindowNameContext ctx) { return getDelegate().visitWindowName(ctx); } + @Nonnull + @Override + public Expressions visitPartitionClause(final RelationalParser.PartitionClauseContext ctx) { + return getDelegate().visitPartitionClause(ctx); + } + @Nonnull @Override public Object visitScalarFunctionName(@Nonnull RelationalParser.ScalarFunctionNameContext ctx) { diff --git a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/TypedVisitor.java b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/TypedVisitor.java index bf13ecc143..6ebd5c1bb9 100644 --- a/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/TypedVisitor.java +++ b/fdb-relational-core/src/main/java/com/apple/foundationdb/relational/recordlayer/query/visitors/TypedVisitor.java @@ -41,6 +41,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -134,21 +135,29 @@ public interface TypedVisitor extends RelationalParserVisitor { @Override RecordLayerTable visitTableDefinition(@Nonnull RelationalParser.TableDefinitionContext ctx); + @Nullable + @Override + Object visitOrganizedByClause(RelationalParser.OrganizedByClauseContext ctx); + @Nonnull @Override - Object visitColumnDefinition(@Nonnull RelationalParser.ColumnDefinitionContext ctx); + Map visitHnswConfigurations(RelationalParser.HnswConfigurationsContext ctx); @Nonnull @Override - DataType visitFunctionColumnType(@Nonnull RelationalParser.FunctionColumnTypeContext ctx); + NonnullPair visitHnswConfiguration(RelationalParser.HnswConfigurationContext ctx); @Nonnull @Override - DataType visitColumnType(@Nonnull RelationalParser.ColumnTypeContext ctx); + Object visitColumnDefinition(@Nonnull RelationalParser.ColumnDefinitionContext ctx); @Nonnull @Override - DataType visitPrimitiveType(@Nonnull RelationalParser.PrimitiveTypeContext ctx); + DataType visitFunctionColumnType(@Nonnull RelationalParser.FunctionColumnTypeContext ctx); + + @Nonnull + @Override + DataType visitColumnType(@Nonnull RelationalParser.ColumnTypeContext ctx); @Nonnull @Override @@ -850,6 +859,10 @@ public interface TypedVisitor extends RelationalParserVisitor { @Override Object visitWindowName(@Nonnull RelationalParser.WindowNameContext ctx); + @Nonnull + @Override + Expressions visitPartitionClause(RelationalParser.PartitionClauseContext ctx); + @Nonnull @Override Object visitScalarFunctionName(@Nonnull RelationalParser.ScalarFunctionNameContext ctx); diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/IndexTest.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/IndexTest.java index c61ada2fc4..a5118071a3 100644 --- a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/IndexTest.java +++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/api/ddl/IndexTest.java @@ -20,7 +20,9 @@ package com.apple.foundationdb.relational.api.ddl; +import com.apple.foundationdb.record.metadata.IndexOptions; import com.apple.foundationdb.record.metadata.IndexTypes; +import com.apple.foundationdb.record.metadata.expressions.FunctionKeyExpression; import com.apple.foundationdb.record.metadata.expressions.GroupingKeyExpression; import com.apple.foundationdb.record.metadata.expressions.KeyExpression; import com.apple.foundationdb.record.metadata.expressions.KeyWithValueExpression; @@ -40,6 +42,7 @@ import com.apple.foundationdb.relational.util.NullableArrayUtils; import com.apple.foundationdb.relational.utils.SimpleDatabaseRule; import com.apple.foundationdb.relational.utils.TestSchemas; +import com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; @@ -111,6 +114,11 @@ private void indexIs(@Nonnull final String stmt, @Nonnull final KeyExpression ex private void indexIs(@Nonnull final String stmt, @Nonnull final KeyExpression expectedKey, @Nonnull final String indexType, @Nonnull final Consumer validator) throws Exception { + indexIs(stmt, expectedKey, indexType, "MV1", validator); + } + + private void indexIs(@Nonnull final String stmt, @Nonnull final KeyExpression expectedKey, @Nonnull final String indexType, + @Nonnull final String indexName, @Nonnull final Consumer validator) throws Exception { shouldWorkWithInjectedFactory(stmt, new AbstractMetadataOperationsFactory() { @Nonnull @Override @@ -123,7 +131,7 @@ public ConstantAction getSaveSchemaTemplateConstantAction(@Nonnull final SchemaT Assertions.assertEquals(1, table.getIndexes().size(), "Incorrect number of indexes!"); final Index index = Assert.optionalUnchecked(table.getIndexes().stream().findFirst()); Assertions.assertInstanceOf(RecordLayerIndex.class, index); - Assertions.assertEquals("MV1", index.getName(), "Incorrect index name!"); + Assertions.assertEquals(indexName, index.getName(), "Incorrect index name!"); Assertions.assertEquals(indexType, index.getIndexType()); final KeyExpression actualKey = KeyExpression.fromProto(((RecordLayerIndex) index).getKeyExpression().toKeyExpression()); Assertions.assertEquals(expectedKey, actualKey); @@ -896,7 +904,7 @@ void createIndexWithOrderByInFromSelect() throws Exception { final String stmt = "CREATE SCHEMA TEMPLATE test_template " + "CREATE TYPE AS STRUCT A(x bigint) " + "CREATE TABLE T(p bigint, a A array, primary key(p))" + - "CREATE INDEX mv1 AS SELECT SQ.x from T AS t, (select M.x from t.a AS M order by M.x) SQ"; + "CREATE INDEX mv1 AS SELECT SQ.x from T AS t, (select Q.x from t.a AS Q order by Q.x) SQ"; shouldFailWith(stmt, ErrorCode.UNSUPPORTED_OPERATION, "order by is not supported in subquery"); } @@ -927,4 +935,50 @@ void createIndexWithOrderByMixedDirection() throws Exception { concat(field("COL1"), function("order_desc_nulls_last", field("COL2")), function("order_asc_nulls_last", field("COL3"))), IndexTypes.VALUE); } + + @Test + void createHnswIndex() throws Exception { + final String stmt = "CREATE SCHEMA TEMPLATE test_template " + + "create table photos(zone string, recordId string, " + + "embedding vector(768), primary key (zone, recordId), organized by hnsw(embedding partition by zone) " + + "with (hnsw_m = 10, hnsw_ef_construction = 5))"; + indexIs(stmt, + new KeyWithValueExpression(concat(field("ZONE"), field("EMBEDDING")), 1), + IndexTypes.VECTOR, "PHOTOS$hnsw", idx -> { + Assertions.assertInstanceOf(RecordLayerIndex.class, idx); + Assertions.assertEquals(ImmutableMap.of(IndexOptions.HNSW_M, "10", IndexOptions.HNSW_EF_CONSTRUCTION, "5"), + ((RecordLayerIndex)idx).getOptions()); + }); + } + + @Test + void createHnswIndexMultiplePartitions() throws Exception { + final String stmt = "CREATE SCHEMA TEMPLATE test_template " + + "create table photos(zone string, recordId string, name string," + + "embedding vector(768), primary key (zone, recordId), organized by hnsw(embedding partition by zone, name) " + + "with (hnsw_m = 10, hnsw_ef_construction = 5))"; + indexIs(stmt, + new KeyWithValueExpression(concat(field("ZONE"), field("NAME"), field("EMBEDDING")), 2), + IndexTypes.VECTOR, "PHOTOS$hnsw", idx -> { + Assertions.assertInstanceOf(RecordLayerIndex.class, idx); + Assertions.assertEquals(ImmutableMap.of(IndexOptions.HNSW_M, "10", IndexOptions.HNSW_EF_CONSTRUCTION, "5"), + ((RecordLayerIndex)idx).getOptions()); + }); + } + + @Test + void createHnswIndexPartitionArithmeticExpression() throws Exception { + final String stmt = "CREATE SCHEMA TEMPLATE test_template " + + "create table photos(zone string, recordId string, name string," + + "embedding vector(768), primary key (zone, recordId), organized by hnsw(embedding partition by zone + 3, name) " + + "with (hnsw_m = 10, hnsw_ef_construction = 5))"; + indexIs(stmt, + new KeyWithValueExpression(concat(FunctionKeyExpression.create("add", concat(field("ZONE"), value(3))) , + field("NAME"), field("EMBEDDING")), 2), + IndexTypes.VECTOR, "PHOTOS$hnsw", idx -> { + Assertions.assertInstanceOf(RecordLayerIndex.class, idx); + Assertions.assertEquals(ImmutableMap.of(IndexOptions.HNSW_M, "10", IndexOptions.HNSW_EF_CONSTRUCTION, "5"), + ((RecordLayerIndex)idx).getOptions()); + }); + } } diff --git a/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/VectorTypeTest.java b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/VectorTypeTest.java new file mode 100644 index 0000000000..b0e8d55d36 --- /dev/null +++ b/fdb-relational-core/src/test/java/com/apple/foundationdb/relational/recordlayer/query/VectorTypeTest.java @@ -0,0 +1,69 @@ +/* + * VectorTypeTest.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.relational.recordlayer.query; + +import com.apple.foundationdb.relational.api.StructResultSetMetaData; +import com.apple.foundationdb.relational.api.metadata.DataType; +import com.apple.foundationdb.relational.recordlayer.EmbeddedRelationalExtension; +import com.apple.foundationdb.relational.utils.Ddl; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.annotation.Nonnull; +import java.net.URI; +import java.util.stream.Stream; + +public class VectorTypeTest { + @RegisterExtension + @Order(0) + public final EmbeddedRelationalExtension relationalExtension = new EmbeddedRelationalExtension(); + + + @Nonnull + public static Stream vectorArguments() { + return Stream.of( + Arguments.of("halfvector(512)", DataType.VectorType.of(16, 512, true)), + Arguments.of("vector16(512)", DataType.VectorType.of(16, 512, true)), + Arguments.of("doublevector(1024)", DataType.VectorType.of(64, 1024, true)), + Arguments.of("vector32(768)", DataType.VectorType.of(32, 768, true)), + Arguments.of("vector(256)", DataType.VectorType.of(16, 256, true))); + } + + @ParameterizedTest(name = "{0} evaluates to data type {1}") + @MethodSource("vectorArguments") + void vectorTest(@Nonnull final String ddlType, @Nonnull final DataType expectedType) throws Exception { + final String schemaTemplate = "create table t1(id bigint, col1 " + ddlType + ", primary key(id))"; + try (var ddl = Ddl.builder().database(URI.create("/TEST/QT")).relationalExtension(relationalExtension).schemaTemplate(schemaTemplate).build()) { + try (var statement = ddl.setSchemaAndGetConnection().createStatement()) { + statement.execute("select * from t1"); + final var metadata = statement.getResultSet().getMetaData(); + Assertions.assertThat(metadata).isInstanceOf(StructResultSetMetaData.class); + final var relationalMetadata = (StructResultSetMetaData)metadata; + final var type = relationalMetadata.getRelationalDataType().getFields().get(1).getType(); + Assertions.assertThat(type).isEqualTo(expectedType); + } + } + } +} diff --git a/gradle/codequality/pmd-rules.xml b/gradle/codequality/pmd-rules.xml index 500ef17c69..4d8745d875 100644 --- a/gradle/codequality/pmd-rules.xml +++ b/gradle/codequality/pmd-rules.xml @@ -16,6 +16,7 @@ + diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index a9abaf9a20..0eabf0d87d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -37,6 +37,7 @@ generatedAnnotation = "1.3.2" grpc = "1.64.1" grpc-commonProtos = "2.37.0" guava = "33.3.1-jre" +half4j = "0.0.2" h2 = "1.3.148" icu = "69.1" lucene = "8.11.1" @@ -95,6 +96,7 @@ grpc-services = { module = "io.grpc:grpc-services", version.ref = "grpc" } grpc-stub = { module = "io.grpc:grpc-stub", version.ref = "grpc" } grpc-util = { module = "io.grpc:grpc-util", version.ref = "grpc" } guava = { module = "com.google.guava:guava", version.ref = "guava" } +half4j = { module = "com.christianheina.langx:half4j", version.ref = "half4j"} icu = { module = "com.ibm.icu:icu4j", version.ref = "icu" } javaPoet = { module = "com.squareup:javapoet", version.ref = "javaPoet" } jsr305 = { module = "com.google.code.findbugs:jsr305", version.ref = "jsr305" } diff --git a/gradle/scripts/log4j-test.properties b/gradle/scripts/log4j-test.properties index 447ee2f55a..1ae7583751 100644 --- a/gradle/scripts/log4j-test.properties +++ b/gradle/scripts/log4j-test.properties @@ -26,7 +26,7 @@ appender.console.name = STDOUT appender.console.layout.type = PatternLayout appender.console.layout.pattern = %d [%level] %logger{1.} - %m %X%n%ex{full} -rootLogger.level = debug +rootLogger.level = info rootLogger.appenderRefs = stdout rootLogger.appenderRef.stdout.ref = STDOUT diff --git a/yaml-tests/src/test/java/YamlIntegrationTests.java b/yaml-tests/src/test/java/YamlIntegrationTests.java index 138f9bcc75..e877d72973 100644 --- a/yaml-tests/src/test/java/YamlIntegrationTests.java +++ b/yaml-tests/src/test/java/YamlIntegrationTests.java @@ -267,8 +267,8 @@ public void enumTest(YamlTest.Runner runner) throws Exception { } @TestTemplate - public void uuidProtoTest(YamlTest.Runner runner) throws Exception { - runner.runYamsql("uuid-proto.yamsql"); + public void uuidTest(YamlTest.Runner runner) throws Exception { + runner.runYamsql("uuid.yamsql"); } @TestTemplate