Skip to content

Commit 3d41be6

Browse files
authored
[vector] support vector search (#6807)
1 parent 21f6a44 commit 3d41be6

25 files changed

+760
-152
lines changed

paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexEvaluator.java

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.paimon.predicate.Predicate;
2626
import org.apache.paimon.predicate.PredicateVisitor;
2727
import org.apache.paimon.predicate.TransformPredicate;
28+
import org.apache.paimon.predicate.VectorSearch;
2829
import org.apache.paimon.types.RowType;
2930
import org.apache.paimon.utils.IOUtils;
3031

@@ -52,18 +53,35 @@ public GlobalIndexEvaluator(
5253
this.readersFunction = readersFunction;
5354
}
5455

55-
public Optional<GlobalIndexResult> evaluate(@Nullable Predicate predicate) {
56-
if (predicate == null) {
57-
return Optional.empty();
56+
public Optional<GlobalIndexResult> evaluate(
57+
@Nullable Predicate predicate, @Nullable VectorSearch vectorSearch) {
58+
Optional<GlobalIndexResult> compoundResult = Optional.empty();
59+
if (predicate != null) {
60+
compoundResult = predicate.visit(this);
5861
}
59-
return predicate.visit(this);
60-
}
62+
if (vectorSearch != null) {
63+
int fieldId = rowType.getField(vectorSearch.fieldName()).id();
64+
Collection<GlobalIndexReader> readers =
65+
indexReadersCache.computeIfAbsent(fieldId, readersFunction::apply);
66+
compoundResult.ifPresent(
67+
globalIndexResult ->
68+
vectorSearch.withIncludeRowIds(globalIndexResult.results().iterator()));
69+
for (GlobalIndexReader fileIndexReader : readers) {
70+
GlobalIndexResult childResult = vectorSearch.visit(fileIndexReader);
71+
// AND Operation
72+
if (compoundResult.isPresent()) {
73+
GlobalIndexResult r1 = compoundResult.get();
74+
compoundResult = Optional.of(r1.and(childResult));
75+
} else {
76+
compoundResult = Optional.of(childResult);
77+
}
6178

62-
public void close() {
63-
IOUtils.closeAllQuietly(
64-
indexReadersCache.values().stream()
65-
.flatMap(Collection::stream)
66-
.collect(Collectors.toList()));
79+
if (compoundResult.get().results().isEmpty()) {
80+
return compoundResult;
81+
}
82+
}
83+
}
84+
return compoundResult;
6785
}
6886

6987
@Override
@@ -135,4 +153,11 @@ public Optional<GlobalIndexResult> visit(CompoundPredicate predicate) {
135153
public Optional<GlobalIndexResult> visit(TransformPredicate predicate) {
136154
return Optional.empty();
137155
}
156+
157+
/**
 * Closes every reader currently held in {@code indexReadersCache}.
 *
 * <p>The per-field reader collections are flattened into a single list and passed to
 * {@code IOUtils.closeAllQuietly}; presumably close failures are suppressed, as that helper's
 * name suggests (confirm against {@code IOUtils}).
 */
public void close() {
158+
IOUtils.closeAllQuietly(
159+
indexReadersCache.values().stream()
160+
.flatMap(Collection::stream)
161+
.collect(Collectors.toList()));
162+
}
138163
}

paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexReader.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.apache.paimon.predicate.FunctionVisitor;
2222
import org.apache.paimon.predicate.TransformPredicate;
23+
import org.apache.paimon.predicate.VectorSearch;
2324

2425
import java.io.Closeable;
2526
import java.util.List;
@@ -41,4 +42,8 @@ default GlobalIndexResult visitOr(List<GlobalIndexResult> children) {
4142
default GlobalIndexResult visit(TransformPredicate predicate) {
4243
throw new UnsupportedOperationException("Should not invoke this");
4344
}
45+
46+
/**
 * Handles a vector search request.
 *
 * <p>The default implementation always throws; readers that can answer vector searches are
 * expected to override it.
 *
 * @throws UnsupportedOperationException always, unless overridden
 */
default GlobalIndexResult visitVectorSearch(VectorSearch vectorSearch) {
47+
throw new UnsupportedOperationException("Should not invoke this");
48+
}
4449
}

paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,11 @@ public void serialize(GlobalIndexResult globalIndexResult, DataOutputView dataOu
6666
dataOutput.writeInt(bytes.length);
6767
dataOutput.write(bytes);
6868

69-
if (globalIndexResult instanceof TopkGlobalIndexResult) {
70-
TopkGlobalIndexResult topkGlobalIndexResult = (TopkGlobalIndexResult) globalIndexResult;
69+
if (globalIndexResult instanceof VectorSearchGlobalIndexResult) {
70+
VectorSearchGlobalIndexResult vectorSearchGlobalIndexResult =
71+
(VectorSearchGlobalIndexResult) globalIndexResult;
7172
dataOutput.writeInt(roaringNavigableMap64.getIntCardinality());
72-
ScoreGetter scoreGetter = topkGlobalIndexResult.scoreGetter();
73+
ScoreGetter scoreGetter = vectorSearchGlobalIndexResult.scoreGetter();
7374
for (Long rowId : roaringNavigableMap64) {
7475
dataOutput.writeFloat(scoreGetter.score(rowId));
7576
}
@@ -114,6 +115,6 @@ public GlobalIndexResult deserialize(DataInputView dataInput) throws IOException
114115
scoreMap.put(rowId, scores[i++]);
115116
}
116117

117-
return TopkGlobalIndexResult.create(() -> roaringNavigableMap64, scoreMap::get);
118+
return VectorSearchGlobalIndexResult.create(() -> roaringNavigableMap64, scoreMap::get);
118119
}
119120
}

paimon-common/src/main/java/org/apache/paimon/globalindex/OffsetGlobalIndexReader.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.paimon.globalindex;
2020

2121
import org.apache.paimon.predicate.FieldRef;
22+
import org.apache.paimon.predicate.VectorSearch;
2223

2324
import java.io.IOException;
2425
import java.util.List;
@@ -107,6 +108,11 @@ public GlobalIndexResult visitNotIn(FieldRef fieldRef, List<Object> literals) {
107108
return applyOffset(wrapped.visitNotIn(fieldRef, literals));
108109
}
109110

111+
// Delegates the vector search to the wrapped reader and passes its result through
// applyOffset, which rejects a null result with an IllegalStateException.
@Override
112+
public GlobalIndexResult visitVectorSearch(VectorSearch vectorSearch) {
113+
return applyOffset(wrapped.visitVectorSearch(vectorSearch));
114+
}
115+
110116
private GlobalIndexResult applyOffset(GlobalIndexResult result) {
111117
if (result == null) {
112118
throw new IllegalStateException("Wrapped reader should not return null");

paimon-common/src/main/java/org/apache/paimon/globalindex/UnionGlobalIndexReader.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.paimon.globalindex;
2020

2121
import org.apache.paimon.predicate.FieldRef;
22+
import org.apache.paimon.predicate.VectorSearch;
2223

2324
import java.io.IOException;
2425
import java.util.List;
@@ -106,6 +107,11 @@ public GlobalIndexResult visitNotIn(FieldRef fieldRef, List<Object> literals) {
106107
return union(reader -> reader.visitNotIn(fieldRef, literals));
107108
}
108109

110+
// Runs the vector search through union(...), which applies the given visitor to the
// underlying readers; presumably the per-reader results are OR-ed together (confirm in union).
@Override
111+
public GlobalIndexResult visitVectorSearch(VectorSearch vectorSearch) {
112+
return union(reader -> reader.visitVectorSearch(vectorSearch));
113+
}
114+
109115
private GlobalIndexResult union(Function<GlobalIndexReader, GlobalIndexResult> visitor) {
110116
GlobalIndexResult result = null;
111117
for (GlobalIndexReader reader : readers) {

paimon-common/src/main/java/org/apache/paimon/globalindex/TopkGlobalIndexResult.java renamed to paimon-common/src/main/java/org/apache/paimon/globalindex/VectorSearchGlobalIndexResult.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323

2424
import java.util.function.Supplier;
2525

26-
/** Top-k global index result for vector index. */
27-
public interface TopkGlobalIndexResult extends GlobalIndexResult {
26+
/** Vector search global index result for vector index. */
27+
public interface VectorSearchGlobalIndexResult extends GlobalIndexResult {
2828

2929
ScoreGetter scoreGetter();
3030

3131
default GlobalIndexResult and(GlobalIndexResult other) {
3232
throw new UnsupportedOperationException("Please realize this by specified global index");
3333
}
3434

35-
default TopkGlobalIndexResult offset(long offset) {
35+
default VectorSearchGlobalIndexResult offset(long offset) {
3636
if (offset == 0) {
3737
return this;
3838
}
@@ -51,17 +51,17 @@ default TopkGlobalIndexResult offset(long offset) {
5151

5252
@Override
5353
default GlobalIndexResult or(GlobalIndexResult other) {
54-
if (!(other instanceof TopkGlobalIndexResult)) {
54+
if (!(other instanceof VectorSearchGlobalIndexResult)) {
5555
return GlobalIndexResult.super.or(other);
5656
}
5757
RoaringNavigableMap64 thisRowIds = results();
5858
ScoreGetter thisScoreGetter = scoreGetter();
5959

6060
RoaringNavigableMap64 otherRowIds = other.results();
61-
ScoreGetter otherScoreGetter = ((TopkGlobalIndexResult) other).scoreGetter();
61+
ScoreGetter otherScoreGetter = ((VectorSearchGlobalIndexResult) other).scoreGetter();
6262

6363
final RoaringNavigableMap64 resultOr = RoaringNavigableMap64.or(thisRowIds, otherRowIds);
64-
return new TopkGlobalIndexResult() {
64+
return new VectorSearchGlobalIndexResult() {
6565
@Override
6666
public ScoreGetter scoreGetter() {
6767
return rowId -> {
@@ -79,11 +79,11 @@ public RoaringNavigableMap64 results() {
7979
};
8080
}
8181

82-
/** Returns a new {@link TopkGlobalIndexResult} from supplier. */
83-
static TopkGlobalIndexResult create(
82+
/** Returns a new {@link VectorSearchGlobalIndexResult} from supplier. */
83+
static VectorSearchGlobalIndexResult create(
8484
Supplier<RoaringNavigableMap64> supplier, ScoreGetter scoreGetter) {
8585
LazyField<RoaringNavigableMap64> lazyField = new LazyField<>(supplier);
86-
return new TopkGlobalIndexResult() {
86+
return new VectorSearchGlobalIndexResult() {
8787
@Override
8888
public ScoreGetter scoreGetter() {
8989
return scoreGetter;
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.predicate;
20+
21+
import org.apache.paimon.globalindex.GlobalIndexReader;
22+
import org.apache.paimon.globalindex.GlobalIndexResult;
23+
24+
import javax.annotation.Nullable;
25+
26+
import java.io.Serializable;
27+
import java.util.Iterator;
28+
import java.util.Optional;
29+
30+
/** VectorSearch to perform vector similarity search. * */
31+
public class VectorSearch implements Serializable {
32+
private static final long serialVersionUID = 1L;
33+
34+
private Object search;
35+
private String fieldName;
36+
private Optional<String> similarityFunction;
37+
private int limit;
38+
private Iterator<Long> includeRowIds;
39+
40+
public VectorSearch(
41+
Object search,
42+
int limit,
43+
String fieldName,
44+
@Nullable Iterator<Long> includeRowIds,
45+
@Nullable String similarityFunction) {
46+
if (search == null) {
47+
throw new IllegalArgumentException("Search cannot be null");
48+
}
49+
if (limit <= 0) {
50+
throw new IllegalArgumentException("Limit must be positive, got: " + limit);
51+
}
52+
if (fieldName == null || fieldName.isEmpty()) {
53+
throw new IllegalArgumentException("Field name cannot be null or empty");
54+
}
55+
this.search = search;
56+
this.limit = limit;
57+
this.fieldName = fieldName;
58+
this.similarityFunction = Optional.ofNullable(similarityFunction);
59+
this.includeRowIds = includeRowIds;
60+
}
61+
62+
public VectorSearch(Object search, int limit, String fieldName) {
63+
this(search, limit, fieldName, null, null);
64+
}
65+
66+
public VectorSearch(Object search, int limit, String fieldName, Iterator<Long> includeRowIds) {
67+
this(search, limit, fieldName, includeRowIds, null);
68+
}
69+
70+
public Object search() {
71+
return search;
72+
}
73+
74+
public int limit() {
75+
return limit;
76+
}
77+
78+
public String fieldName() {
79+
return fieldName;
80+
}
81+
82+
public Optional<String> similarityFunction() {
83+
return similarityFunction;
84+
}
85+
86+
public Iterator<Long> includeRowIds() {
87+
return includeRowIds;
88+
}
89+
90+
public VectorSearch withIncludeRowIds(Iterator<Long> includeRowIds) {
91+
this.includeRowIds = includeRowIds;
92+
return this;
93+
}
94+
95+
public GlobalIndexResult visit(GlobalIndexReader visitor) {
96+
return visitor.visitVectorSearch(this);
97+
}
98+
99+
@Override
100+
public String toString() {
101+
return String.format(
102+
"FieldName(%s), SimilarityFunction(%s), Limit(%s)",
103+
fieldName, similarityFunction, limit);
104+
}
105+
}

paimon-common/src/test/java/org/apache/paimon/globalindex/GlobalIndexSerDeUtilsTest.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public void testSerializeAndDeserializeGlobalIndexResult() throws IOException {
4141
byte[] serialized = serialize(original);
4242
GlobalIndexResult deserialized = deserialize(serialized);
4343

44-
assertThat(deserialized).isNotInstanceOf(TopkGlobalIndexResult.class);
44+
assertThat(deserialized).isNotInstanceOf(VectorSearchGlobalIndexResult.class);
4545
assertThat(deserialized.results()).isEqualTo(bitmap);
4646
}
4747

@@ -52,7 +52,7 @@ public void testSerializeAndDeserializeEmptyGlobalIndexResult() throws IOExcepti
5252
byte[] serialized = serialize(original);
5353
GlobalIndexResult deserialized = deserialize(serialized);
5454

55-
assertThat(deserialized).isNotInstanceOf(TopkGlobalIndexResult.class);
55+
assertThat(deserialized).isNotInstanceOf(VectorSearchGlobalIndexResult.class);
5656
assertThat(deserialized.results().isEmpty()).isTrue();
5757
}
5858

@@ -65,15 +65,16 @@ public void testSerializeAndDeserializeTopkGlobalIndexResult() throws IOExceptio
6565
scoreMap.put(10L, 0.7f);
6666
scoreMap.put(100L, 0.6f);
6767

68-
TopkGlobalIndexResult original = TopkGlobalIndexResult.create(() -> bitmap, scoreMap::get);
68+
VectorSearchGlobalIndexResult original =
69+
VectorSearchGlobalIndexResult.create(() -> bitmap, scoreMap::get);
6970

7071
byte[] serialized = serialize(original);
7172
GlobalIndexResult deserialized = deserialize(serialized);
7273

73-
assertThat(deserialized).isInstanceOf(TopkGlobalIndexResult.class);
74+
assertThat(deserialized).isInstanceOf(VectorSearchGlobalIndexResult.class);
7475
assertThat(deserialized.results()).isEqualTo(bitmap);
7576

76-
TopkGlobalIndexResult topkResult = (TopkGlobalIndexResult) deserialized;
77+
VectorSearchGlobalIndexResult topkResult = (VectorSearchGlobalIndexResult) deserialized;
7778
ScoreGetter scoreGetter = topkResult.scoreGetter();
7879
assertThat(scoreGetter.score(1L)).isEqualTo(0.9f);
7980
assertThat(scoreGetter.score(5L)).isEqualTo(0.8f);
@@ -91,15 +92,16 @@ public void testSerializeAndDeserializeTopkWithLargeRowIds() throws IOException
9192
scoreMap.put(Integer.MAX_VALUE + 100L, 0.3f);
9293
scoreMap.put(Long.MAX_VALUE - 1, 0.1f);
9394

94-
TopkGlobalIndexResult original = TopkGlobalIndexResult.create(() -> bitmap, scoreMap::get);
95+
VectorSearchGlobalIndexResult original =
96+
VectorSearchGlobalIndexResult.create(() -> bitmap, scoreMap::get);
9597

9698
byte[] serialized = serialize(original);
9799
GlobalIndexResult deserialized = deserialize(serialized);
98100

99-
assertThat(deserialized).isInstanceOf(TopkGlobalIndexResult.class);
101+
assertThat(deserialized).isInstanceOf(VectorSearchGlobalIndexResult.class);
100102
assertThat(deserialized.results()).isEqualTo(bitmap);
101103

102-
TopkGlobalIndexResult topkResult = (TopkGlobalIndexResult) deserialized;
104+
VectorSearchGlobalIndexResult topkResult = (VectorSearchGlobalIndexResult) deserialized;
103105
ScoreGetter scoreGetter = topkResult.scoreGetter();
104106
assertThat(scoreGetter.score(Integer.MAX_VALUE + 1L)).isEqualTo(0.5f);
105107
assertThat(scoreGetter.score(Integer.MAX_VALUE + 100L)).isEqualTo(0.3f);

0 commit comments

Comments
 (0)