diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java new file mode 100644 index 000000000..a5ffcfa31 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java @@ -0,0 +1,69 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.util.Arrays; + +public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues { + private final RandomAccessVectorValues ravv; + private final int[] graphToRavvOrdMap; + + /** + * Remaps a RAVV to a different set of ordinals. This is useful when the ordinals used by the graph + * do not match the ordinals used by the RAVV. + * + * @param ravv the RAVV to remap + * @param graphToRavvOrdMap a mapping from the graph's ordinals to the RAVV's ordinals where + * graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i. 
+ */ + public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) { + this.ravv = ravv; + this.graphToRavvOrdMap = graphToRavvOrdMap; + } + + @Override + public int size() { + return graphToRavvOrdMap.length; + } + + @Override + public int dimension() { + return ravv.dimension(); + } + + @Override + public VectorFloat getVector(int node) { + return ravv.getVector(graphToRavvOrdMap[node]); + } + + @Override + public boolean isValueShared() { + return ravv.isValueShared(); + } + + @Override + public RandomAccessVectorValues copy() { + return new RemappedRandomAccessVectorValues(ravv.copy(), Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length)); + } + + @Override + public void getVectorInto(int node, VectorFloat result, int offset) { + ravv.getVectorInto(graphToRavvOrdMap[node], result, offset); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index 0ffdf72eb..f0b184e67 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RemappedRandomAccessVectorValues; import io.github.jbellis.jvector.quantization.BQVectors; import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -25,8 +26,6 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.util.stream.IntStream; - /** * Encapsulates comparing node distances for GraphIndexBuilder. 
*/ @@ -88,15 +87,15 @@ public interface BuildScoreProvider { * * Helper method for the special case that mapping between graph node IDs and ravv ordinals is the identity function. */ - static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { - return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction); + static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { + return randomAccessScoreProvider(new RemappedRandomAccessVectorValues(ravv, graphToRavvOrdMap), similarityFunction); } /** * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction. * graphToRavvOrdMap maps graph node IDs to ravv ordinals. */ - static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) { + static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared. 
var vectors = ravv.threadLocalSupplier(); @@ -125,22 +124,22 @@ public VectorFloat approximateCentroid() { @Override public SearchScoreProvider searchProviderFor(VectorFloat vector) { var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc); } @Override public SearchScoreProvider searchProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); + var v = randomAccessVectorValues.getVector(node1); return searchProviderFor(v); } @Override public SearchScoreProvider diversityProviderFor(int node1) { RandomAccessVectorValues randomAccessVectorValues = vectors.get(); - var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]); + var v = randomAccessVectorValues.getVector(node1); var vc = vectorsCopy.get(); - return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc); + return DefaultSearchScoreProvider.exact(v, similarityFunction, vc); } }; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 716621d21..59b248584 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -156,4 +156,19 @@ public void testSaveAndLoad() throws IOException { } assertGraphEquals(graph, builder.graph); } + + // Because RandomAccessVectorValues is exposed in such a way that it allows for subsequent additions to the + // vector source, we need to ensure that GraphIndexBuilder can handle this. 
+ @Test + public void testAddNodesToVectorValuesIteratively() throws IOException { + int dimension = randomIntBetween(2, 32); + var mutableVectors = new ArrayList<VectorFloat<?>>(); + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(mutableVectors, dimension); + try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true)) { + for (int i = 0; i < 10; i++) { + mutableVectors.add(TestUtil.randomVector(random(), dimension)); + builder.addGraphNode(i, ravv.getVector(i)); + } + } + } }