Skip to content

Commit 531f736

Browse files
Reduce the number of vector allocations in BuildScoreProvider.pqBuilderScoreProvider (#419)
Reduce the number of vector allocations in BuildScoreProvider.pqBuildScoreProvider. Addresses #418
1 parent d312aa7 commit 531f736

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,11 @@ public SearchScoreProvider diversityProviderFor(int node1) {
138138
* with reranking performed using RandomAccessVectorValues (which is intended to be
139139
* InlineVectorValues for building incrementally, but should technically
140140
* work with any RAVV implementation).
141+
* This class is not thread safe, we should never publish its results to another thread.
141142
*/
142143
static BuildScoreProvider pqBuildScoreProvider(VectorSimilarityFunction vsf, PQVectors pqv) {
143144
int dimension = pqv.getOriginalSize() / Float.BYTES;
145+
final ThreadLocal<VectorFloat<?>> reusableVector = ThreadLocal.withInitial(() -> vts.createFloatVector(dimension));;
144146

145147
return new BuildScoreProvider() {
146148
@Override
@@ -153,15 +155,15 @@ public SearchScoreProvider diversityProviderFor(int node1) {
153155
// like searchProviderFor, this skips reranking; unlike sPF, it uses pqv.scoreFunctionFor
154156
// instead of precomputedScoreFunctionFor; since we only perform a few dozen comparisons
155157
// during diversity computation, this is cheaper than precomputing a lookup table
156-
VectorFloat<?> v1 = vts.createFloatVector(dimension);
158+
VectorFloat<?> v1 = reusableVector.get();
157159
pqv.getCompressor().decode(pqv.get(node1), v1);
158160
var asf = pqv.scoreFunctionFor(v1, vsf); // not precomputed!
159161
return new SearchScoreProvider(asf);
160162
}
161163

162164
@Override
163165
public SearchScoreProvider searchProviderFor(int node1) {
164-
VectorFloat<?> decoded = vts.createFloatVector(dimension);
166+
VectorFloat<?> decoded = reusableVector.get();
165167
pqv.getCompressor().decode(pqv.get(node1), decoded);
166168
return searchProviderFor(decoded);
167169
}

0 commit comments

Comments
 (0)