Skip to content

Commit 768c5af

Browse files
committed
SOLR-17164: Add 2 arg variant of vectorSimilarity() function
1 parent f67b718 commit 768c5af

File tree

7 files changed

+622
-53
lines changed

7 files changed

+622
-53
lines changed

solr/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ Improvements
120120
* SOLR-17172: Add QueryLimits termination to the existing heavy SearchComponent-s. This allows query limits (e.g. timeAllowed,
121121
cpuAllowed) to terminate expensive operations within components if limits are exceeded. (Andrzej Bialecki)
122122

123+
* SOLR-17164: Add 2 arg variant of vectorSimilarity() function (Sanjay Dutt, hossman)
124+
123125
Optimizations
124126
---------------------
125127
* SOLR-17144: Close searcherExecutor thread per core after 1 minute (Pierre Salagnac, Christine Poerschke)

solr/core/src/java/org/apache/solr/search/ValueSourceParser.java

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,19 @@
2626
import java.util.Map;
2727
import org.apache.lucene.index.LeafReaderContext;
2828
import org.apache.lucene.index.Term;
29-
import org.apache.lucene.index.VectorEncoding;
30-
import org.apache.lucene.index.VectorSimilarityFunction;
3129
import org.apache.lucene.queries.function.FunctionScoreQuery;
3230
import org.apache.lucene.queries.function.FunctionValues;
3331
import org.apache.lucene.queries.function.ValueSource;
3432
import org.apache.lucene.queries.function.docvalues.BoolDocValues;
3533
import org.apache.lucene.queries.function.docvalues.DoubleDocValues;
3634
import org.apache.lucene.queries.function.docvalues.LongDocValues;
37-
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
3835
import org.apache.lucene.queries.function.valuesource.ConstNumberSource;
3936
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
4037
import org.apache.lucene.queries.function.valuesource.DefFunction;
4138
import org.apache.lucene.queries.function.valuesource.DivFloatFunction;
4239
import org.apache.lucene.queries.function.valuesource.DocFreqValueSource;
4340
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
4441
import org.apache.lucene.queries.function.valuesource.DualFloatFunction;
45-
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
4642
import org.apache.lucene.queries.function.valuesource.IDFValueSource;
4743
import org.apache.lucene.queries.function.valuesource.IfFunction;
4844
import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource;
@@ -344,41 +340,7 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
344340
}
345341
});
346342
alias("sum", "add");
347-
addParser(
348-
"vectorSimilarity",
349-
new ValueSourceParser() {
350-
@Override
351-
public ValueSource parse(FunctionQParser fp) throws SyntaxError {
352-
353-
VectorEncoding vectorEncoding = VectorEncoding.valueOf(fp.parseArg());
354-
VectorSimilarityFunction functionName = VectorSimilarityFunction.valueOf(fp.parseArg());
355-
356-
int vectorEncodingFlag =
357-
vectorEncoding.equals(VectorEncoding.BYTE)
358-
? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
359-
: 0;
360-
ValueSource v1 =
361-
fp.parseValueSource(
362-
FunctionQParser.FLAG_DEFAULT
363-
| FunctionQParser.FLAG_CONSUME_DELIMITER
364-
| vectorEncodingFlag);
365-
ValueSource v2 =
366-
fp.parseValueSource(
367-
FunctionQParser.FLAG_DEFAULT
368-
| FunctionQParser.FLAG_CONSUME_DELIMITER
369-
| vectorEncodingFlag);
370-
371-
switch (vectorEncoding) {
372-
case FLOAT32:
373-
return new FloatVectorSimilarityFunction(functionName, v1, v2);
374-
case BYTE:
375-
return new ByteVectorSimilarityFunction(functionName, v1, v2);
376-
default:
377-
throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
378-
}
379-
}
380-
});
381-
343+
addParser("vectorSimilarity", new VectorSimilaritySourceParser());
382344
addParser(
383345
"product",
384346
new ValueSourceParser() {
solr/core/src/java/org/apache/solr/search/VectorSimilaritySourceParser.java

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.solr.search;
18+
19+
import static org.apache.solr.common.SolrException.ErrorCode;
20+
import static org.apache.solr.common.SolrException.ErrorCode.BAD_REQUEST;
21+
22+
import java.util.Arrays;
23+
import java.util.Locale;
24+
import org.apache.lucene.index.VectorEncoding;
25+
import org.apache.lucene.index.VectorSimilarityFunction;
26+
import org.apache.lucene.queries.function.ValueSource;
27+
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
28+
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
29+
import org.apache.solr.common.SolrException;
30+
import org.apache.solr.schema.DenseVectorField;
31+
import org.apache.solr.schema.FieldType;
32+
import org.apache.solr.schema.SchemaField;
33+
34+
/**
35+
* This class provides implementation for two variants for parsing function query vectorSimilarity
36+
* which is used to calculate the similarity between two vectors.
37+
*/
38+
public class VectorSimilaritySourceParser extends ValueSourceParser {
39+
@Override
40+
public ValueSource parse(FunctionQParser fp) throws SyntaxError {
41+
42+
final String arg1Str = fp.parseArg();
43+
if (arg1Str == null || !fp.hasMoreArguments())
44+
throw new SolrException(
45+
BAD_REQUEST, "Invalid number of arguments. Please provide either two or four arguments.");
46+
47+
final String arg2Str = peekIsConstVector(fp) ? null : fp.parseArg();
48+
if (fp.hasMoreArguments() && arg2Str != null) {
49+
return handle4ArgsVariant(fp, arg1Str, arg2Str);
50+
}
51+
return handle2ArgsVariant(fp, arg1Str, arg2Str);
52+
}
53+
54+
/**
55+
* returns true if and only if the next argument is a constant vector, taking into consideration
56+
* that the next (literal) argument may be a param reference
57+
*/
58+
private boolean peekIsConstVector(final FunctionQParser fp) throws SyntaxError {
59+
final char rawPeek = fp.sp.peek();
60+
if ('[' == rawPeek) {
61+
return true;
62+
}
63+
if ('$' == rawPeek) {
64+
final int savedPos = fp.sp.pos;
65+
try {
66+
final String rawParam = fp.parseArg();
67+
return ((null != rawParam) && ('[' == (new StrParser(rawParam)).peek()));
68+
} finally {
69+
fp.sp.pos = savedPos;
70+
}
71+
}
72+
return false;
73+
}
74+
75+
private static int buildVectorEncodingFlag(final VectorEncoding vectorEncoding) {
76+
return FunctionQParser.FLAG_DEFAULT
77+
| FunctionQParser.FLAG_CONSUME_DELIMITER
78+
| (vectorEncoding.equals(VectorEncoding.BYTE)
79+
? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
80+
: 0);
81+
}
82+
83+
/** Expects to find args #3 and #4 (two vector ValueSources) still in the function parser */
84+
private ValueSource handle4ArgsVariant(FunctionQParser fp, String vecEncStr, String vecSimFuncStr)
85+
throws SyntaxError {
86+
final var vectorEncoding = enumValueOrBadRequest(VectorEncoding.class, vecEncStr);
87+
final var vectorSimilarityFunction =
88+
enumValueOrBadRequest(VectorSimilarityFunction.class, vecSimFuncStr);
89+
final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
90+
final ValueSource v1 = fp.parseValueSource(vectorEncodingFlag);
91+
final ValueSource v2 = fp.parseValueSource(vectorEncodingFlag);
92+
return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
93+
}
94+
95+
/**
96+
* If <code>field2Name</code> is null, then expects to find a constant vector as the only
97+
* remaining arg in the function parser.
98+
*/
99+
private ValueSource handle2ArgsVariant(FunctionQParser fp, String field1Name, String field2Name)
100+
throws SyntaxError {
101+
102+
final SchemaField field1 = fp.req.getSchema().getField(field1Name);
103+
final DenseVectorField field1Type = requireVectorType(field1);
104+
105+
final var vectorEncoding = field1Type.getVectorEncoding();
106+
final var vectorSimilarityFunction = field1Type.getSimilarityFunction();
107+
108+
final ValueSource v1 = field1Type.getValueSource(field1, fp);
109+
final ValueSource v2;
110+
111+
if (null == field2Name) {
112+
final int vectorEncodingFlag = buildVectorEncodingFlag(vectorEncoding);
113+
v2 = fp.parseValueSource(vectorEncodingFlag);
114+
115+
} else {
116+
final SchemaField field2 = fp.req.getSchema().getField(field2Name);
117+
final DenseVectorField field2Type = requireVectorType(field2);
118+
if (vectorEncoding != field2Type.getVectorEncoding()
119+
|| vectorSimilarityFunction != field2Type.getSimilarityFunction()) {
120+
throw new SolrException(
121+
BAD_REQUEST,
122+
String.format(
123+
Locale.ROOT,
124+
"Invalid arguments: vector field %s and vector field %s must have the same vectorEncoding and similarityFunction",
125+
field1.getName(),
126+
field2.getName()));
127+
}
128+
v2 = field2Type.getValueSource(field2, fp);
129+
}
130+
return createSimilarityFunction(vectorSimilarityFunction, vectorEncoding, v1, v2);
131+
}
132+
133+
private ValueSource createSimilarityFunction(
134+
VectorSimilarityFunction functionName,
135+
VectorEncoding vectorEncoding,
136+
ValueSource v1,
137+
ValueSource v2)
138+
throws SyntaxError {
139+
switch (vectorEncoding) {
140+
case FLOAT32:
141+
return new FloatVectorSimilarityFunction(functionName, v1, v2);
142+
case BYTE:
143+
return new ByteVectorSimilarityFunction(functionName, v1, v2);
144+
default:
145+
throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
146+
}
147+
}
148+
149+
private DenseVectorField requireVectorType(final SchemaField field) throws SyntaxError {
150+
final FieldType fieldType = field.getType();
151+
if (fieldType instanceof DenseVectorField) {
152+
return (DenseVectorField) field.getType();
153+
}
154+
throw new SolrException(
155+
BAD_REQUEST,
156+
String.format(
157+
Locale.ROOT,
158+
"Type mismatch: Expected [%s], but found a different field type for field: [%s]",
159+
DenseVectorField.class.getSimpleName(),
160+
field.getName()));
161+
}
162+
163+
/**
164+
* Helper method that returns the correct Enum instance for the <code>arg</code> String, or throws
165+
* a {@link ErrorCode#BAD_REQUEST} with specifics on the "Invalid argument"
166+
*/
167+
private static <T extends Enum<T>> T enumValueOrBadRequest(
168+
final Class<T> enumClass, final String arg) throws SolrException {
169+
assert null != enumClass;
170+
try {
171+
return Enum.valueOf(enumClass, arg);
172+
} catch (IllegalArgumentException | NullPointerException e) {
173+
throw new SolrException(
174+
BAD_REQUEST,
175+
String.format(
176+
Locale.ROOT,
177+
"Invalid argument: %s is not a valid %s. Expected one of %s",
178+
arg,
179+
enumClass.getSimpleName(),
180+
Arrays.toString(enumClass.getEnumConstants())));
181+
}
182+
}
183+
}

solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -911,13 +911,67 @@ public void testFuncVector() throws Exception {
911911
}
912912

913913
/**
 * Passes groups of vectorSimilarity() function strings to assertFuncEquals. Within a group the
 * strings differ only in whitespace, in whether an argument is written literally or supplied via
 * a request param reference ($name), in explicit field(...) wrapping, or in using the 2-arg vs.
 * the 4-arg form. Presumably every string in a group is expected to parse to an equal function --
 * confirm against assertFuncEquals.
 *
 * NOTE(review): this span is a diff rendering; lines following an "NNN-" gutter marker appear to
 * be the removed pre-commit assertions and lines following "NNN+" the added ones.
 */
public void testFuncKnnVector() throws Exception {
914-
assertFuncEquals(
915-
"vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
916-
"vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])");
914+
// 4-arg form, constant vectors: literal vs. params v1/v2/v3 whose values differ only in spacing
try (SolrQueryRequest req =
915+
req(
916+
"v1", "[1,2,3]",
917+
"v2", " [1,2,3] ",
918+
"v3", " [1, 2, 3] ")) {
919+
assertFuncEquals(
920+
req,
921+
"vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
922+
"vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])",
923+
"vectorSimilarity(FLOAT32, COSINE,$v1, [4, 5, 6])",
924+
"vectorSimilarity(FLOAT32, COSINE, $v2 , [4, 5, 6])",
925+
"vectorSimilarity(FLOAT32, COSINE, $v3 , [4, 5, 6])");
926+
}
917927

918-
assertFuncEquals(
919-
"vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
920-
"vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])");
928+
// 4-arg form, field argument: bare name, field(...) wrapper, and params f1/f2/f3
try (SolrQueryRequest req =
929+
req(
930+
"f1", "bar_i",
931+
"f2", " bar_i ",
932+
"f3", " field(bar_i) ")) {
933+
assertFuncEquals(
934+
req,
935+
"vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])",
936+
"vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])",
937+
"vectorSimilarity(BYTE, EUCLIDEAN,$f1, [4, 5, 6])",
938+
"vectorSimilarity(BYTE, EUCLIDEAN, $f1, [4, 5, 6])",
939+
"vectorSimilarity(BYTE, EUCLIDEAN, $f2, [4, 5, 6])",
940+
"vectorSimilarity(BYTE, EUCLIDEAN, $f3, [4, 5, 6])");
941+
}
942+
943+
// "vector" field vs. a constant vector: 4-arg and 2-arg forms in one group, with the field
// name and the vector each given literally or through a param
try (SolrQueryRequest req =
944+
req(
945+
"f", "vector",
946+
"v1", "[1,2,3,4]",
947+
"v2", " [1, 2, 3, 4]")) {
948+
assertFuncEquals(
949+
req,
950+
"vectorSimilarity(FLOAT32,COSINE,vector,[1,2,3,4])",
951+
"vectorSimilarity(FLOAT32,COSINE,vector,$v1)",
952+
"vectorSimilarity(FLOAT32,COSINE,vector, $v1)",
953+
"vectorSimilarity(FLOAT32,COSINE,vector,$v2)",
954+
"vectorSimilarity(FLOAT32,COSINE,vector, $v2)",
955+
"vectorSimilarity(vector,[1,2,3,4])",
956+
"vectorSimilarity( vector,[1,2,3,4])",
957+
"vectorSimilarity( $f,[1,2,3,4])",
958+
"vectorSimilarity(vector,$v1)",
959+
"vectorSimilarity(vector, $v1)",
960+
"vectorSimilarity( $f, $v1)",
961+
"vectorSimilarity(vector,$v2)",
962+
"vectorSimilarity(vector, $v2)");
963+
}
964+
965+
// contrived, but helps us test the param resolution
966+
// for both field names in the 2arg usecase
967+
try (SolrQueryRequest req = req("f", "vector")) {
968+
assertFuncEquals(
969+
req,
970+
"vectorSimilarity($f, $f)",
971+
"vectorSimilarity($f, vector)",
972+
"vectorSimilarity(vector, $f)",
973+
"vectorSimilarity(vector, vector)");
974+
}
921975
}
922976

923977
public void testFuncQuery() throws Exception {

0 commit comments

Comments
 (0)