Skip to content

Commit 06d868d

Browse files
finnroblinkaivalnp
authored andcommitted
Fix MergedByteVectorValues internal ordinal tracking (#15553)
Move internal ordinal tracking in `MergedByteVectorValues` from `vectorValue` -> `nextDoc` to allow loading only a subset of vectors during iteration.
1 parent db2ebf7 commit 06d868d

File tree

3 files changed

+115
-8
lines changed

3 files changed

+115
-8
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ Bug Fixes
232232
* GITHUB#15554: Fix tessellator failure by preferring the shared vertex that is the leftmost vertex of the hole
233233
(Ignacio Vera)
234234

235+
* GITHUB#15553: Fix MergedByteVectorValues lastOrd behavior to enable partitioning of vector values
236+
(Finn Roblin, Dooyong Kim)
237+
235238
Other
236239
---------------------
237240
* GITHUB#15237: Fix SmartChinese to only deserialize dictionary data from classpath with a native array

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,12 @@ private void finishMerge(MergeState mergeState) throws IOException {
121121
}
122122
}
123123

124-
/** Tracks state of one sub-reader that we are merging */
125-
private static class FloatVectorValuesSub extends DocIDMerger.Sub {
124+
/**
125+
* Tracks state of one sub-reader of float vectors that we are merging.
126+
*
127+
* @lucene.internal
128+
*/
129+
static class FloatVectorValuesSub extends DocIDMerger.Sub {
126130

127131
final FloatVectorValues values;
128132
final KnnVectorValues.DocIndexIterator iterator;
@@ -144,7 +148,12 @@ public int index() {
144148
}
145149
}
146150

147-
private static class ByteVectorValuesSub extends DocIDMerger.Sub {
151+
/**
152+
* Tracks state of one sub-reader of byte vectors that we are merging.
153+
*
154+
* @lucene.internal
155+
*/
156+
static class ByteVectorValuesSub extends DocIDMerger.Sub {
148157

149158
final ByteVectorValues values;
150159
final KnnVectorValues.DocIndexIterator iterator;
@@ -303,6 +312,11 @@ public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeS
303312
mergeState);
304313
}
305314

315+
/**
316+
* Unified view over several segments containing float vector values.
317+
*
318+
* @lucene.internal
319+
*/
306320
static class MergedFloat32VectorValues extends FloatVectorValues {
307321
private final List<FloatVectorValuesSub> subs;
308322
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
@@ -311,7 +325,8 @@ static class MergedFloat32VectorValues extends FloatVectorValues {
311325
private int lastOrd = -1;
312326
FloatVectorValuesSub current;
313327

314-
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
328+
// package-private for testing
329+
MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
315330
throws IOException {
316331
this.subs = subs;
317332
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
@@ -401,6 +416,11 @@ public FloatVectorValues copy() {
401416
}
402417
}
403418

419+
/**
420+
* Unified view over several segments containing byte vector values.
421+
*
422+
* @lucene.internal
423+
*/
404424
static class MergedByteVectorValues extends ByteVectorValues {
405425
private final List<ByteVectorValuesSub> subs;
406426
private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
@@ -410,7 +430,8 @@ static class MergedByteVectorValues extends ByteVectorValues {
410430
private int docId = -1;
411431
ByteVectorValuesSub current;
412432

413-
private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
433+
// package-private for testing
434+
MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
414435
throws IOException {
415436
this.subs = subs;
416437
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
@@ -423,11 +444,9 @@ private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeS
423444

424445
@Override
425446
public byte[] vectorValue(int ord) throws IOException {
426-
if (ord != lastOrd + 1) {
447+
if (ord != lastOrd) {
427448
throw new IllegalStateException(
428449
"only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd);
429-
} else {
430-
lastOrd = ord;
431450
}
432451
return current.values.vectorValue(current.index());
433452
}
@@ -455,6 +474,7 @@ public int nextDoc() throws IOException {
455474
index = NO_MORE_DOCS;
456475
} else {
457476
docId = current.mappedDocID;
477+
++lastOrd;
458478
++index;
459479
}
460480
return docId;
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.codecs;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import org.apache.lucene.index.ByteVectorValues;
22+
import org.apache.lucene.index.FloatVectorValues;
23+
import org.apache.lucene.index.MergeState;
24+
import org.apache.lucene.tests.util.LuceneTestCase;
25+
26+
/** Tests for merged vector values to ensure lastOrd is properly incremented during iteration. */
27+
public class TestMergedVectorValues extends LuceneTestCase {
28+
29+
/**
30+
* Test that skipping vectors in MergedByteVectorValues via nextDoc() and then loading a
31+
* subsequent vector via vectorValue() works correctly.
32+
*/
33+
public void testSkipsInMergedByteVectorValues() throws IOException {
34+
// Data
35+
List<byte[]> vectors = List.of(new byte[] {0}, new byte[] {1});
36+
37+
// Setup
38+
KnnVectorsWriter.ByteVectorValuesSub sub =
39+
new KnnVectorsWriter.ByteVectorValuesSub(x -> x, ByteVectorValues.fromBytes(vectors, 1));
40+
MergeState state =
41+
new MergeState(
42+
null, null, null, null, null, null, null, null, null, null, null, null, null, null,
43+
null, false);
44+
45+
// Run the test
46+
ByteVectorValues values =
47+
new KnnVectorsWriter.MergedVectorValues.MergedByteVectorValues(List.of(sub), state);
48+
49+
// Skip doc 0 and load doc 1
50+
values.iterator().nextDoc(); // doc 0
51+
values.iterator().nextDoc(); // doc 1
52+
53+
// Read vector for doc 1
54+
assertArrayEquals(vectors.get(1), values.vectorValue(1));
55+
}
56+
57+
/**
58+
* Test that skipping vectors in MergedFloat32VectorValues via nextDoc() and then loading a
59+
* subsequent vector via vectorValue() works correctly.
60+
*/
61+
public void testSkipsInMergedFloat32VectorValues() throws IOException {
62+
// Data
63+
List<float[]> vectors = List.of(new float[] {0.0f}, new float[] {1.0f});
64+
65+
// Setup
66+
KnnVectorsWriter.FloatVectorValuesSub sub =
67+
new KnnVectorsWriter.FloatVectorValuesSub(x -> x, FloatVectorValues.fromFloats(vectors, 1));
68+
MergeState state =
69+
new MergeState(
70+
null, null, null, null, null, null, null, null, null, null, null, null, null, null,
71+
null, false);
72+
73+
// Run the test
74+
FloatVectorValues values =
75+
new KnnVectorsWriter.MergedVectorValues.MergedFloat32VectorValues(List.of(sub), state);
76+
77+
// Skip doc 0 and load doc 1
78+
values.iterator().nextDoc(); // doc 0
79+
values.iterator().nextDoc(); // doc 1
80+
81+
// Read vector for doc 1
82+
assertArrayEquals(vectors.get(1), values.vectorValue(1), 0.0f);
83+
}
84+
}

0 commit comments

Comments
 (0)