Skip to content

Commit c7d344a

Browse files
benwtrentjoshua-adams-1
authored andcommitted
Remove unnecessary knn param boxing and simplyfing somethings (elastic#128693)
some refactoring I noticed recently that we can do. Now k is always provided we can remove boxing and this simplifies some logic. Additionally, modernizes and simplifies some tests. No behavior change in this PR.
1 parent 3083ce0 commit c7d344a

File tree

8 files changed

+67
-64
lines changed

8 files changed

+67
-64
lines changed

server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import org.elasticsearch.search.profile.query.QueryProfiler;
1818

1919
public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
20-
private final Integer kParam;
20+
private final int kParam;
2121
private long vectorOpsCount;
2222

2323
public ESDiversifyingChildrenByteKnnVectorQuery(
2424
String field,
2525
byte[] query,
2626
Query childFilter,
27-
Integer k,
27+
int k,
2828
int numCands,
2929
BitSetProducer parentsFilter,
3030
KnnSearchStrategy strategy
@@ -35,7 +35,7 @@ public ESDiversifyingChildrenByteKnnVectorQuery(
3535

3636
@Override
3737
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
38-
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
38+
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
3939
vectorOpsCount = topK.totalHits.value();
4040
return topK;
4141
}

server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import org.elasticsearch.search.profile.query.QueryProfiler;
1818

1919
public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider {
20-
private final Integer kParam;
20+
private final int kParam;
2121
private long vectorOpsCount;
2222

2323
public ESDiversifyingChildrenFloatKnnVectorQuery(
2424
String field,
2525
float[] query,
2626
Query childFilter,
27-
Integer k,
27+
int k,
2828
int numCands,
2929
BitSetProducer parentsFilter,
3030
KnnSearchStrategy strategy
@@ -35,7 +35,7 @@ public ESDiversifyingChildrenFloatKnnVectorQuery(
3535

3636
@Override
3737
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
38-
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
38+
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
3939
vectorOpsCount = topK.totalHits.value();
4040
return topK;
4141
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@
1616
import org.elasticsearch.search.profile.query.QueryProfiler;
1717

1818
public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
19-
private final Integer kParam;
19+
private final int kParam;
2020
private long vectorOpsCount;
2121

22-
public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
22+
public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
2323
super(field, target, numCands, filter, strategy);
2424
this.kParam = k;
2525
}
2626

2727
@Override
2828
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
2929
// if k param is set, we get only top k results from each shard
30-
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
30+
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
3131
vectorOpsCount = topK.totalHits.value();
3232
return topK;
3333
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@
1616
import org.elasticsearch.search.profile.query.QueryProfiler;
1717

1818
public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
19-
private final Integer kParam;
19+
private final int kParam;
2020
private long vectorOpsCount;
2121

22-
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
22+
public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
2323
super(field, target, numCands, filter, strategy);
2424
this.kParam = k;
2525
}
2626

2727
@Override
2828
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
2929
// if k param is set, we get only top k results from each shard
30-
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
30+
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
3131
vectorOpsCount = topK.totalHits.value();
3232
return topK;
3333
}
@@ -37,7 +37,7 @@ public void profile(QueryProfiler queryProfiler) {
3737
queryProfiler.addVectorOpsCount(vectorOpsCount);
3838
}
3939

40-
public Integer kParam() {
40+
public int kParam() {
4141
return kParam;
4242
}
4343

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ private static void checkRescoreQueryParameters(
712712
int k,
713713
int candidates,
714714
float oversample,
715-
Integer expectedK,
715+
int expectedK,
716716
int expectedCandidates,
717717
int expectedResults
718718
) {

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,16 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
187187
assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
188188
query = ((VectorSimilarityQuery) query).getInnerKnnQuery();
189189
}
190-
Integer k = queryBuilder.k();
191-
if (k == null) {
190+
int k;
191+
if (queryBuilder.k() == null) {
192192
k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
193+
} else {
194+
k = queryBuilder.k();
193195
}
194196
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
195197
if (queryBuilder.rescoreVectorBuilder().oversample() > 0) {
196198
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
197-
assertEquals(k.intValue(), (rescoreQuery.k()));
199+
assertEquals(k, (rescoreQuery.k()));
198200
query = rescoreQuery.innerQuery();
199201
} else {
200202
assertFalse(query instanceof RescoreKnnVectorQuery);
@@ -213,7 +215,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
213215
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
214216
Integer numCands = queryBuilder.numCands();
215217
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
216-
Float oversample = queryBuilder.rescoreVectorBuilder().oversample();
218+
float oversample = queryBuilder.rescoreVectorBuilder().oversample();
217219
k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample));
218220
numCands = Math.max(numCands, k);
219221
}

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -110,93 +110,94 @@ protected KnnSearchBuilder createTestInstance() {
110110

111111
@Override
112112
protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) {
113-
switch (random().nextInt(8)) {
114-
case 0:
113+
return switch (random().nextInt(8)) {
114+
case 0 -> {
115115
String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5));
116-
return new KnnSearchBuilder(
116+
yield new KnnSearchBuilder(
117117
newField,
118118
instance.queryVector,
119119
instance.k,
120120
instance.numCands,
121121
instance.getRescoreVectorBuilder(),
122122
instance.similarity
123123
).boost(instance.boost);
124-
case 1:
124+
}
125+
case 1 -> {
125126
float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5));
126-
return new KnnSearchBuilder(
127+
yield new KnnSearchBuilder(
127128
instance.field,
128129
newVector,
129130
instance.k,
130131
instance.numCands,
131132
instance.getRescoreVectorBuilder(),
132133
instance.similarity
133134
).boost(instance.boost);
134-
case 2:
135+
}
136+
case 2 -> {
135137
// given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe
136138
Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10));
137-
return new KnnSearchBuilder(
139+
yield new KnnSearchBuilder(
138140
instance.field,
139141
instance.queryVector,
140142
newK,
141143
instance.numCands,
142144
instance.getRescoreVectorBuilder(),
143145
instance.similarity
144146
).boost(instance.boost);
145-
case 3:
147+
}
148+
case 3 -> {
146149
Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100));
147-
return new KnnSearchBuilder(
150+
yield new KnnSearchBuilder(
148151
instance.field,
149152
instance.queryVector,
150153
instance.k,
151154
newNumCands,
152155
instance.getRescoreVectorBuilder(),
153156
instance.similarity
154157
).boost(instance.boost);
155-
case 4:
156-
return new KnnSearchBuilder(
157-
instance.field,
158-
instance.queryVector,
159-
instance.k,
160-
instance.numCands,
161-
instance.getRescoreVectorBuilder(),
162-
instance.similarity
163-
).addFilterQueries(instance.filterQueries)
164-
.addFilterQuery(QueryBuilders.termQuery("new_field", "new-value"))
165-
.boost(instance.boost);
166-
case 5:
158+
}
159+
case 4 -> new KnnSearchBuilder(
160+
instance.field,
161+
instance.queryVector,
162+
instance.k,
163+
instance.numCands,
164+
instance.getRescoreVectorBuilder(),
165+
instance.similarity
166+
).addFilterQueries(instance.filterQueries)
167+
.addFilterQuery(QueryBuilders.termQuery("new_field", "new-value"))
168+
.boost(instance.boost);
169+
case 5 -> {
167170
float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat);
168-
return new KnnSearchBuilder(
171+
yield new KnnSearchBuilder(
169172
instance.field,
170173
instance.queryVector,
171174
instance.k,
172175
instance.numCands,
173176
instance.getRescoreVectorBuilder(),
174177
instance.similarity
175178
).addFilterQueries(instance.filterQueries).boost(newBoost);
176-
case 6:
177-
return new KnnSearchBuilder(
178-
instance.field,
179-
instance.queryVector,
180-
instance.k,
181-
instance.numCands,
179+
}
180+
case 6 -> new KnnSearchBuilder(
181+
instance.field,
182+
instance.queryVector,
183+
instance.k,
184+
instance.numCands,
185+
instance.getRescoreVectorBuilder(),
186+
randomValueOtherThan(instance.similarity, ESTestCase::randomFloat)
187+
).addFilterQueries(instance.filterQueries).boost(instance.boost);
188+
case 7 -> new KnnSearchBuilder(
189+
instance.field,
190+
instance.queryVector,
191+
instance.k,
192+
instance.numCands,
193+
randomValueOtherThan(
182194
instance.getRescoreVectorBuilder(),
183-
randomValueOtherThan(instance.similarity, ESTestCase::randomFloat)
184-
).addFilterQueries(instance.filterQueries).boost(instance.boost);
185-
case 7:
186-
return new KnnSearchBuilder(
187-
instance.field,
188-
instance.queryVector,
189-
instance.k,
190-
instance.numCands,
191-
randomValueOtherThan(
192-
instance.getRescoreVectorBuilder(),
193-
() -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false))
194-
),
195-
instance.similarity
196-
).addFilterQueries(instance.filterQueries).boost(instance.boost);
197-
default:
198-
throw new IllegalStateException();
199-
}
195+
() -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false))
196+
),
197+
instance.similarity
198+
).addFilterQueries(instance.filterQueries).boost(instance.boost);
199+
default -> throw new IllegalStateException();
200+
};
200201
}
201202

202203
public void testToQueryBuilder() {

server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static class TestQueryVectorBuilder implements QueryVectorBuilder {
4545
PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR);
4646
}
4747

48-
private List<Float> vectorToBuild;
48+
private final List<Float> vectorToBuild;
4949

5050
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
5151
this.vectorToBuild = vectorToBuild;

0 commit comments

Comments
 (0)