Skip to content

Commit da7230a

Browse files
committed
SOLR-17369: Fix "flags" usage in FunctionQParser that caused some issues in vectorSimilarity() with BYTE vector constants
1 parent d567066 commit da7230a

File tree

5 files changed

+75
-49
lines changed

5 files changed

+75
-49
lines changed

solr/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ Bug Fixes
206206

207207
* SOLR-17367: Restore the use of -params option to PostTool. (Bostoi via Eric Pugh)
208208

209+
* SOLR-17369: Fix "flags" usage in FunctionQParser that caused some issues in vectorSimilarity() with BYTE vector constants (hossman)
210+
209211
Dependency Upgrades
210212
---------------------
211213
(No changes)

solr/core/src/java/org/apache/solr/search/FunctionQParser.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ public class FunctionQParser extends QParser {
5959
public FunctionQParser(
6060
String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
6161
super(qstr, localParams, params, req);
62+
setFlags(FLAG_DEFAULT);
6263
setString(qstr);
6364
}
6465

@@ -89,15 +90,12 @@ public boolean getParseToEnd() {
8990
}
9091

9192
@Override
92-
@SuppressWarnings("ErroneousBitwiseExpression")
9393
public Query parse() throws SyntaxError {
9494
ValueSource vs = null;
9595
List<ValueSource> lst = null;
9696

9797
for (; ; ) {
98-
// @SuppressWarnings("ErroneousBitwiseExpression") is needed since
99-
// FLAG_DEFAULT & ~FLAG_CONSUME_DELIMITER == 0
100-
ValueSource valsource = parseValueSource(FLAG_DEFAULT & ~FLAG_CONSUME_DELIMITER);
98+
ValueSource valsource = parseValueSource(getFlags() & ~FLAG_CONSUME_DELIMITER);
10199
sp.eatws();
102100
if (!parseMultipleSources) {
103101
vs = valsource;
@@ -298,7 +296,7 @@ public List<Number> parseVector(VectorEncoding encoding) throws SyntaxError {
298296
* @return List&lt;ValueSource&gt;
299297
*/
300298
public List<ValueSource> parseValueSourceList() throws SyntaxError {
301-
return parseValueSourceList(FLAG_DEFAULT | FLAG_CONSUME_DELIMITER);
299+
return parseValueSourceList(getFlags() | FLAG_CONSUME_DELIMITER);
302300
}
303301

304302
/**
@@ -318,7 +316,7 @@ public List<ValueSource> parseValueSourceList(int flags) throws SyntaxError {
318316
/** Parse an individual ValueSource. */
319317
public ValueSource parseValueSource() throws SyntaxError {
320318
/* consume the delimiter afterward for an external call to parseValueSource */
321-
return parseValueSource(FLAG_DEFAULT | FLAG_CONSUME_DELIMITER);
319+
return parseValueSource(getFlags() | FLAG_CONSUME_DELIMITER);
322320
}
323321

324322
/*
@@ -386,14 +384,11 @@ public Query parseNestedQuery() throws SyntaxError {
386384
*
387385
* @param doConsumeDelimiter whether to consume a delimiter following the ValueSource
388386
*/
389-
@SuppressWarnings("ErroneousBitwiseExpression")
390387
protected ValueSource parseValueSource(boolean doConsumeDelimiter) throws SyntaxError {
391-
// @SuppressWarnings("ErroneousBitwiseExpression") is needed since
392-
// FLAG_DEFAULT & ~FLAG_CONSUME_DELIMITER == 0
393388
return parseValueSource(
394389
doConsumeDelimiter
395-
? (FLAG_DEFAULT | FLAG_CONSUME_DELIMITER)
396-
: (FLAG_DEFAULT & ~FLAG_CONSUME_DELIMITER));
390+
? (getFlags() | FLAG_CONSUME_DELIMITER)
391+
: (getFlags() & ~FLAG_CONSUME_DELIMITER));
397392
}
398393

399394
protected ValueSource parseValueSource(int flags) throws SyntaxError {
@@ -430,7 +425,9 @@ protected ValueSource parseValueSource(int flags) throws SyntaxError {
430425
} else {
431426
QParser subParser = subQuery(val, "func");
432427
if (subParser instanceof FunctionQParser) {
433-
((FunctionQParser) subParser).setParseMultipleSources(true);
428+
FunctionQParser subFunc = (FunctionQParser) subParser;
429+
subFunc.setParseMultipleSources(true);
430+
subFunc.setFlags(flags);
434431
}
435432
Query subQuery = subParser.getQuery();
436433
if (subQuery == null) {

solr/core/src/test-files/solr/collection1/conf/schema15.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
<!-- Dense Vector Fields -->
4747
<fieldType name="knn_vector" class="solr.DenseVectorField" vectorDimension="4" similarityFunction="cosine"/>
48+
<fieldType name="knn_vector_byte" class="solr.DenseVectorField" vectorDimension="4" similarityFunction="cosine" vectorEncoding="BYTE" />
4849

4950
<!-- Field type demonstrating an Analyzer failure -->
5051
<fieldType name="failtype1" class="solr.TextField">
@@ -564,6 +565,7 @@
564565

565566
<!-- Dense Vector-->
566567
<field name="vector" type="knn_vector" indexed="true" stored="true"/>
568+
<field name="vector_byte" type="knn_vector_byte" indexed="true" stored="true"/>
567569

568570
<dynamicField name="*_sI" type="string" indexed="true" stored="false"/>
569571
<dynamicField name="*_sS" type="string" indexed="false" stored="true"/>

solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ public void testFuncKnnVector() throws Exception {
915915
req(
916916
"v1", "[1,2,3]",
917917
"v2", " [1,2,3] ",
918-
"v3", " [1, 2, 3] ")) {
918+
"v3", " [1, 2, 3.0] ")) {
919919
assertFuncEquals(
920920
req,
921921
"vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])",
@@ -962,6 +962,28 @@ public void testFuncKnnVector() throws Exception {
962962
"vectorSimilarity(vector, $v2)");
963963
}
964964

965+
try (SolrQueryRequest req =
966+
req(
967+
"f", "vector_byte",
968+
"v1", "[1,2,3,4]",
969+
"v2", " [1, 2, 3, 4]")) {
970+
assertFuncEquals(
971+
req,
972+
"vectorSimilarity(BYTE,COSINE,vector_byte,[1,2,3,4])",
973+
"vectorSimilarity(BYTE,COSINE,vector_byte,$v1)",
974+
"vectorSimilarity(BYTE,COSINE,vector_byte, $v1)",
975+
"vectorSimilarity(BYTE,COSINE,vector_byte,$v2)",
976+
"vectorSimilarity(BYTE,COSINE,vector_byte, $v2)",
977+
"vectorSimilarity(vector_byte,[1,2,3,4])",
978+
"vectorSimilarity( vector_byte,[1,2,3,4])",
979+
"vectorSimilarity( $f,[1,2,3,4])",
980+
"vectorSimilarity(vector_byte,$v1)",
981+
"vectorSimilarity(vector_byte, $v1)",
982+
"vectorSimilarity( $f, $v1)",
983+
"vectorSimilarity(vector_byte,$v2)",
984+
"vectorSimilarity(vector_byte, $v2)");
985+
}
986+
965987
// contrived, but helps us test the param resolution
966988
// for both field names in the 2arg usecase
967989
try (SolrQueryRequest req = req("f", "vector")) {

solr/core/src/test/org/apache/solr/search/function/TestDenseVectorFunctionQuery.java

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.solr.common.SolrException;
2424
import org.apache.solr.common.SolrInputDocument;
2525
import org.apache.solr.common.params.CommonParams;
26+
import org.apache.solr.common.params.SolrParams;
2627
import org.junit.After;
2728
import org.junit.Before;
2829
import org.junit.Test;
@@ -285,46 +286,48 @@ public void testReportsErrorInvalidArgs() {
285286

286287
@Test
287288
public void test2ArgsByteFieldAndConstVector() throws Exception {
288-
assertQ(
289-
req(
290-
CommonParams.Q,
291-
"{!func} vectorSimilarity(vector_byte_encoding, [1,2,3,3])",
292-
"fq",
293-
"id:(1 2)",
294-
"fl",
295-
"id, score",
296-
"rows",
297-
"1"),
298-
"//result[@numFound='" + 2 + "']",
299-
"//result/doc[1]/str[@name='id'][.=1]");
300-
assertQ(
301-
req(
302-
CommonParams.Q,
303-
"{!func} vectorSimilarity(vector_byte_encoding, [3,3,2,1])",
304-
"fq",
305-
"id:(1 2)",
306-
"fl",
307-
"id, score",
308-
"rows",
309-
"1"),
310-
"//result[@numFound='" + 2 + "']",
311-
"//result/doc[1]/str[@name='id'][.=2]");
289+
for (SolrParams main :
290+
Arrays.asList(
291+
params(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding, [1,2,3,3])"),
292+
params(
293+
CommonParams.Q,
294+
"{!func} vectorSimilarity(vector_byte_encoding, $vec)",
295+
"vec",
296+
"[1,2,3,3]"))) {
297+
assertQ(
298+
req(main, "fq", "id:(1 2)", "fl", "id, score", "rows", "1"),
299+
"//result[@numFound='" + 2 + "']",
300+
"//result/doc[1]/str[@name='id'][.=1]");
301+
}
302+
for (SolrParams main :
303+
Arrays.asList(
304+
params(CommonParams.Q, "{!func} vectorSimilarity(vector_byte_encoding, [3,3,2,1])"),
305+
params(
306+
CommonParams.Q,
307+
"{!func} vectorSimilarity(vector_byte_encoding, $vec)",
308+
"vec",
309+
"[3,3,2,1]"))) {
310+
311+
assertQ(
312+
req(main, "fq", "id:(1 2)", "fl", "id, score", "rows", "1"),
313+
"//result[@numFound='" + 2 + "']",
314+
"//result/doc[1]/str[@name='id'][.=2]");
315+
}
312316
}
313317

314318
@Test
315319
public void test2ArgsFloatFieldAndConstVector() throws Exception {
316-
assertQ(
317-
req(
318-
CommonParams.Q,
319-
"{!func} vectorSimilarity(vector, [1,2,3,3])",
320-
"fq",
321-
"id:(1 2 3)",
322-
"fl",
323-
"id, score"),
324-
"//result[@numFound='" + 3 + "']",
325-
"//result/doc[1]/str[@name='id'][.=2]",
326-
"//result/doc[2]/str[@name='id'][.=3]",
327-
"//result/doc[3]/str[@name='id'][.=1]");
320+
for (SolrParams main :
321+
Arrays.asList(
322+
params(CommonParams.Q, "{!func} vectorSimilarity(vector, [1,2,3,3])"),
323+
params(CommonParams.Q, "{!func} vectorSimilarity(vector, $vec)", "vec", "[1,2,3,3]"))) {
324+
assertQ(
325+
req(main, "fq", "id:(1 2 3)", "fl", "id, score"),
326+
"//result[@numFound='" + 3 + "']",
327+
"//result/doc[1]/str[@name='id'][.=2]",
328+
"//result/doc[2]/str[@name='id'][.=3]",
329+
"//result/doc[3]/str[@name='id'][.=1]");
330+
}
328331
}
329332

330333
@Test

0 commit comments

Comments
 (0)