Skip to content

Commit ee095bc

Browse files
committed
[core] Refactor VectorSearch to keep required fields
1 parent 3d41be6 commit ee095bc

File tree

10 files changed

+63
-104
lines changed

10 files changed

+63
-104
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ paimon-python/.idea/
2424
paimon-python/dist/
2525
paimon-python/*.egg-info/
2626
paimon-python/dev/log
27+
paimon-lucene/.idea/
2728

2829
### VS Code ###
2930
.vscode/

paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -63,9 +63,9 @@ public Optional<GlobalIndexResult> evaluate(
6363
int fieldId = rowType.getField(vectorSearch.fieldName()).id();
6464
Collection<GlobalIndexReader> readers =
6565
indexReadersCache.computeIfAbsent(fieldId, readersFunction::apply);
66-
compoundResult.ifPresent(
67-
globalIndexResult ->
68-
vectorSearch.withIncludeRowIds(globalIndexResult.results().iterator()));
66+
if (compoundResult.isPresent()) {
67+
vectorSearch = vectorSearch.withIncludeRowIds(compoundResult.get().results());
68+
}
6969
for (GlobalIndexReader fileIndexReader : readers) {
7070
GlobalIndexResult childResult = vectorSearch.visit(fileIndexReader);
7171
// AND Operation

paimon-common/src/main/java/org/apache/paimon/predicate/VectorSearch.java

Lines changed: 18 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -20,30 +20,26 @@
2020

2121
import org.apache.paimon.globalindex.GlobalIndexReader;
2222
import org.apache.paimon.globalindex.GlobalIndexResult;
23+
import org.apache.paimon.utils.RoaringNavigableMap64;
2324

2425
import javax.annotation.Nullable;
2526

2627
import java.io.Serializable;
27-
import java.util.Iterator;
28-
import java.util.Optional;
2928

3029
/** VectorSearch to perform vector similarity search. */
3130
public class VectorSearch implements Serializable {
31+
3232
private static final long serialVersionUID = 1L;
3333

34-
private Object search;
35-
private String fieldName;
36-
private Optional<String> similarityFunction;
37-
private int limit;
38-
private Iterator<Long> includeRowIds;
39-
40-
public VectorSearch(
41-
Object search,
42-
int limit,
43-
String fieldName,
44-
@Nullable Iterator<Long> includeRowIds,
45-
@Nullable String similarityFunction) {
46-
if (search == null) {
34+
// float[] or byte[]
35+
private final Object vector;
36+
private final String fieldName;
37+
private final int limit;
38+
39+
@Nullable private RoaringNavigableMap64 includeRowIds;
40+
41+
public VectorSearch(Object vector, int limit, String fieldName) {
42+
if (vector == null) {
4743
throw new IllegalArgumentException("Search cannot be null");
4844
}
4945
if (limit <= 0) {
@@ -52,23 +48,14 @@ public VectorSearch(
5248
if (fieldName == null || fieldName.isEmpty()) {
5349
throw new IllegalArgumentException("Field name cannot be null or empty");
5450
}
55-
this.search = search;
51+
this.vector = vector;
5652
this.limit = limit;
5753
this.fieldName = fieldName;
58-
this.similarityFunction = Optional.ofNullable(similarityFunction);
59-
this.includeRowIds = includeRowIds;
60-
}
61-
62-
public VectorSearch(Object search, int limit, String fieldName) {
63-
this(search, limit, fieldName, null, null);
64-
}
65-
66-
public VectorSearch(Object search, int limit, String fieldName, Iterator<Long> includeRowIds) {
67-
this(search, limit, fieldName, includeRowIds, null);
6854
}
6955

70-
public Object search() {
71-
return search;
56+
// float[] or byte[]
57+
public Object vector() {
58+
return vector;
7259
}
7360

7461
public int limit() {
@@ -79,15 +66,11 @@ public String fieldName() {
7966
return fieldName;
8067
}
8168

82-
public Optional<String> similarityFunction() {
83-
return similarityFunction;
84-
}
85-
86-
public Iterator<Long> includeRowIds() {
69+
public RoaringNavigableMap64 includeRowIds() {
8770
return includeRowIds;
8871
}
8972

90-
public VectorSearch withIncludeRowIds(Iterator<Long> includeRowIds) {
73+
public VectorSearch withIncludeRowIds(RoaringNavigableMap64 includeRowIds) {
9174
this.includeRowIds = includeRowIds;
9275
return this;
9376
}
@@ -98,8 +81,6 @@ public GlobalIndexResult visit(GlobalIndexReader visitor) {
9881

9982
@Override
10083
public String toString() {
101-
return String.format(
102-
"FieldName(%s), SimilarityFunction(%s), Limit(%s)",
103-
fieldName, similarityFunction, limit);
84+
return String.format("FieldName(%s), Limit(%s)", fieldName, limit);
10485
}
10586
}

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/LuceneVectorGlobalIndexReader.java

Lines changed: 19 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import org.apache.paimon.types.DataType;
3030
import org.apache.paimon.types.FloatType;
3131
import org.apache.paimon.types.TinyIntType;
32+
import org.apache.paimon.utils.IOUtils;
3233
import org.apache.paimon.utils.Range;
3334
import org.apache.paimon.utils.RoaringNavigableMap64;
3435

@@ -68,17 +69,14 @@ public class LuceneVectorGlobalIndexReader implements GlobalIndexReader {
6869
private final GlobalIndexFileReader fileReader;
6970
private final GlobalIndexResult defaultResult;
7071
private volatile boolean indicesLoaded = false;
71-
private final LuceneVectorIndexOptions vectorIndexOptions;
7272
private final DataType fieldType;
7373

7474
public LuceneVectorGlobalIndexReader(
7575
GlobalIndexFileReader fileReader,
7676
List<GlobalIndexIOMeta> ioMetas,
77-
LuceneVectorIndexOptions options,
7877
DataType fieldType) {
7978
this.fileReader = fileReader;
8079
this.ioMetas = ioMetas;
81-
this.vectorIndexOptions = options;
8280
this.fieldType = fieldType;
8381
this.searchers = new ArrayList<>();
8482
this.directories = new ArrayList<>();
@@ -88,23 +86,17 @@ public LuceneVectorGlobalIndexReader(
8886
@Override
8987
public GlobalIndexResult visitVectorSearch(VectorSearch vectorSearch) {
9088
try {
91-
if (vectorSearch.similarityFunction().isEmpty()
92-
|| LuceneVectorMetric.fromString(vectorSearch.similarityFunction().get())
93-
== vectorIndexOptions.metric()) {
94-
ensureLoadIndices(fileReader, ioMetas);
95-
Query query = query(vectorSearch, fieldType);
96-
return search(query, vectorSearch.limit());
97-
}
89+
ensureLoadIndices(fileReader, ioMetas);
90+
Query query = query(vectorSearch, fieldType);
91+
return search(query, vectorSearch.limit());
9892
} catch (IOException e) {
9993
throw new RuntimeException(
10094
String.format(
101-
"Failed to search vector index with fieldName=%s, similarity=%s, limit=%d",
95+
"Failed to search vector index with fieldName=%s, limit=%d",
10296
vectorSearch.fieldName(),
103-
vectorIndexOptions.metric(),
10497
vectorSearch.limit()),
10598
e);
10699
}
107-
return defaultResult;
108100
}
109101

110102
@Override
@@ -153,32 +145,35 @@ public void close() throws IOException {
153145

154146
private Query query(VectorSearch vectorSearch, DataType dataType) {
155147
Query idFilterQuery = null;
156-
Iterator<Long> includeRowIds = vectorSearch.includeRowIds();
148+
RoaringNavigableMap64 includeRowIds = vectorSearch.includeRowIds();
157149
if (includeRowIds != null) {
158-
ArrayList<Long> targetIds = new ArrayList<>();
159-
includeRowIds.forEachRemaining(id -> targetIds.add(id));
150+
long[] targetIds = new long[includeRowIds.getIntCardinality()];
151+
Iterator<Long> iterator = includeRowIds.iterator();
152+
for (int i = 0; i < targetIds.length; i++) {
153+
targetIds[i] = iterator.next();
154+
}
160155
idFilterQuery = LongPoint.newSetQuery(ROW_ID_FIELD, targetIds);
161156
}
162157
if (dataType instanceof ArrayType
163158
&& ((ArrayType) dataType).getElementType() instanceof FloatType) {
164-
if (!(vectorSearch.search() instanceof float[])) {
159+
if (!(vectorSearch.vector() instanceof float[])) {
165160
throw new IllegalArgumentException(
166-
"Expected float[] vector but got: " + vectorSearch.search().getClass());
161+
"Expected float[] vector but got: " + vectorSearch.vector().getClass());
167162
}
168163
return new KnnFloatVectorQuery(
169164
LuceneVectorIndex.VECTOR_FIELD,
170-
(float[]) vectorSearch.search(),
165+
(float[]) vectorSearch.vector(),
171166
vectorSearch.limit(),
172167
idFilterQuery);
173168
} else if (dataType instanceof ArrayType
174169
&& ((ArrayType) dataType).getElementType() instanceof TinyIntType) {
175-
if (!(vectorSearch.search() instanceof byte[])) {
170+
if (!(vectorSearch.vector() instanceof byte[])) {
176171
throw new IllegalArgumentException(
177-
"Expected byte[] vector but got: " + vectorSearch.search().getClass());
172+
"Expected byte[] vector but got: " + vectorSearch.vector().getClass());
178173
}
179174
return new KnnByteVectorQuery(
180175
LuceneVectorIndex.VECTOR_FIELD,
181-
(byte[]) vectorSearch.search(),
176+
(byte[]) vectorSearch.vector(),
182177
vectorSearch.limit(),
183178
idFilterQuery);
184179
} else {
@@ -249,19 +244,8 @@ private void ensureLoadIndices(GlobalIndexFileReader fileReader, List<GlobalInde
249244
indicesLoaded = true;
250245
} finally {
251246
if (!indicesLoaded) {
252-
if (reader != null) {
253-
try {
254-
reader.close();
255-
} catch (IOException e) {
256-
}
257-
}
258-
if (directory != null) {
259-
try {
260-
directory.close();
261-
} catch (Exception e) {
262-
throw new IOException("Failed to close directory", e);
263-
}
264-
}
247+
IOUtils.closeQuietly(reader);
248+
IOUtils.closeQuietly(directory);
265249
}
266250
}
267251
}

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/LuceneVectorGlobalIndexWriter.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ public class LuceneVectorGlobalIndexWriter implements GlobalIndexWriter {
5252
private final LuceneVectorIndexFactory vectorIndexFactory;
5353

5454
private long count = 0;
55-
private final List<LuceneVectorIndex> vectorIndices;
55+
private final List<LuceneVectorIndex<?>> vectorIndices;
5656
private final List<ResultEntry> results;
5757

5858
public LuceneVectorGlobalIndexWriter(
@@ -70,7 +70,7 @@ public LuceneVectorGlobalIndexWriter(
7070

7171
@Override
7272
public void write(Object key) {
73-
LuceneVectorIndex index = vectorIndexFactory.create(count, key);
73+
LuceneVectorIndex<?> index = vectorIndexFactory.create(count, key);
7474
index.checkDimension(vectorIndexOptions.dimension());
7575
vectorIndices.add(index);
7676
if (vectorIndices.size() >= sizePerIndex) {
@@ -113,7 +113,7 @@ private void flush() throws IOException {
113113
}
114114

115115
private void buildIndex(
116-
List<LuceneVectorIndex> batchVectors,
116+
List<LuceneVectorIndex<?>> batchVectors,
117117
int m,
118118
int efConstruction,
119119
int writeBufferSize,
@@ -123,7 +123,7 @@ private void buildIndex(
123123
try (LuceneIndexMMapDirectory luceneIndexMMapDirectory = new LuceneIndexMMapDirectory()) {
124124
try (IndexWriter writer =
125125
new IndexWriter(luceneIndexMMapDirectory.directory(), config)) {
126-
for (LuceneVectorIndex luceneVectorIndex : batchVectors) {
126+
for (LuceneVectorIndex<?> luceneVectorIndex : batchVectors) {
127127
Document doc = new Document();
128128
doc.add(luceneVectorIndex.indexableField(similarityFunction));
129129
doc.add(luceneVectorIndex.rowIdLongPoint());

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/LuceneVectorGlobalIndexer.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ public GlobalIndexWriter createWriter(GlobalIndexFileWriter fileWriter) {
4848

4949
@Override
5050
public GlobalIndexReader createReader(
51-
GlobalIndexFileReader fileReader, List<GlobalIndexIOMeta> files) throws IOException {
52-
return new LuceneVectorGlobalIndexReader(fileReader, files, options, fieldType);
51+
GlobalIndexFileReader fileReader, List<GlobalIndexIOMeta> files) {
52+
return new LuceneVectorGlobalIndexReader(fileReader, files, fieldType);
5353
}
5454
}

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/LuceneVectorIndexFactory.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,20 +38,20 @@ public static LuceneVectorIndexFactory init(DataType dataType) {
3838
}
3939
}
4040

41-
public abstract LuceneVectorIndex create(long rowId, Object vector);
41+
public abstract LuceneVectorIndex<?> create(long rowId, Object vector);
4242

4343
/** Factory for creating LuceneFloatVectorIndex instances. */
4444
public static class LuceneFloatVectorIndexFactory extends LuceneVectorIndexFactory {
4545
@Override
46-
public LuceneVectorIndex create(long rowId, Object vector) {
46+
public LuceneVectorIndex<?> create(long rowId, Object vector) {
4747
return new LuceneFloatVectorIndex(rowId, (float[]) vector);
4848
}
4949
}
5050

5151
/** Factory for creating LuceneByteVectorIndex instances. */
5252
public static class LuceneByteVectorIndexFactory extends LuceneVectorIndexFactory {
5353
@Override
54-
public LuceneVectorIndex create(long rowId, Object vector) {
54+
public LuceneVectorIndex<?> create(long rowId, Object vector) {
5555
return new LuceneByteVectorIndex(rowId, (byte[]) vector);
5656
}
5757
}

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/LuceneVectorSearchGlobalIndexResult.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ public LuceneVectorSearchGlobalIndexResult(
3838

3939
@Override
4040
public ScoreGetter scoreGetter() {
41-
return rowId -> id2scores.get(rowId);
41+
return id2scores::get;
4242
}
4343

4444
@Override

paimon-lucene/src/test/java/org/apache/paimon/lucene/index/LuceneVectorGlobalIndexScanTest.java

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -68,18 +68,17 @@ public class LuceneVectorGlobalIndexScanTest {
6868

6969
private FileStoreTable table;
7070
private String commitUser;
71-
private Path tablePath;
7271
private FileIO fileIO;
7372
private RowType rowType;
74-
private String similarityMetric = "EUCLIDEAN";
75-
private String vectorFieldName = "vec";
73+
private final String vectorFieldName = "vec";
7674

7775
@BeforeEach
7876
public void before() throws Exception {
79-
tablePath = new Path(tempDir.toString());
77+
Path tablePath = new Path(tempDir.toString());
8078
fileIO = new LocalFileIO();
8179
SchemaManager schemaManager = new SchemaManager(fileIO, tablePath);
8280

81+
String similarityMetric = "EUCLIDEAN";
8382
Schema schema =
8483
Schema.newBuilder()
8584
.column("id", DataTypes.INT())

0 commit comments

Comments
 (0)