Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private final int kParam;
private long vectorOpsCount;

public ESDiversifyingChildrenByteKnnVectorQuery(
String field,
byte[] query,
Query childFilter,
Integer k,
int k,
int numCands,
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
Expand All @@ -35,7 +35,7 @@ public ESDiversifyingChildrenByteKnnVectorQuery(

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
vectorOpsCount = topK.totalHits.value();
return topK;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private final int kParam;
private long vectorOpsCount;

public ESDiversifyingChildrenFloatKnnVectorQuery(
String field,
float[] query,
Query childFilter,
Integer k,
int k,
int numCands,
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
Expand All @@ -35,7 +35,7 @@ public ESDiversifyingChildrenFloatKnnVectorQuery(

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
vectorOpsCount = topK.totalHits.value();
return topK;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private final int kParam;
private long vectorOpsCount;

public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
// if k param is set, we get only top k results from each shard
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
vectorOpsCount = topK.totalHits.value();
return topK;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private final int kParam;
private long vectorOpsCount;

public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}

@Override
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
// if k param is set, we get only top k results from each shard
TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
TopDocs topK = TopDocs.merge(kParam, perLeafResults);
vectorOpsCount = topK.totalHits.value();
return topK;
}
Expand All @@ -37,7 +37,7 @@ public void profile(QueryProfiler queryProfiler) {
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

public Integer kParam() {
public int kParam() {
return kParam;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ private static void checkRescoreQueryParameters(
int k,
int candidates,
float oversample,
Integer expectedK,
int expectedK,
int expectedCandidates,
int expectedResults
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,16 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
query = ((VectorSimilarityQuery) query).getInnerKnnQuery();
}
Integer k = queryBuilder.k();
if (k == null) {
int k;
if (queryBuilder.k() == null) {
k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
} else {
k = queryBuilder.k();
}
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
if (queryBuilder.rescoreVectorBuilder().oversample() > 0) {
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
assertEquals(k.intValue(), (rescoreQuery.k()));
assertEquals(k, (rescoreQuery.k()));
query = rescoreQuery.innerQuery();
} else {
assertFalse(query instanceof RescoreKnnVectorQuery);
Expand All @@ -213,7 +215,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
Integer numCands = queryBuilder.numCands();
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
Float oversample = queryBuilder.rescoreVectorBuilder().oversample();
float oversample = queryBuilder.rescoreVectorBuilder().oversample();
k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample));
numCands = Math.max(numCands, k);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,93 +110,94 @@ protected KnnSearchBuilder createTestInstance() {

@Override
protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) {
switch (random().nextInt(8)) {
case 0:
return switch (random().nextInt(8)) {
case 0 -> {
String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5));
return new KnnSearchBuilder(
yield new KnnSearchBuilder(
newField,
instance.queryVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).boost(instance.boost);
case 1:
}
case 1 -> {
float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5));
return new KnnSearchBuilder(
yield new KnnSearchBuilder(
instance.field,
newVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).boost(instance.boost);
case 2:
}
case 2 -> {
// given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe
Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10));
return new KnnSearchBuilder(
yield new KnnSearchBuilder(
instance.field,
instance.queryVector,
newK,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).boost(instance.boost);
case 3:
}
case 3 -> {
Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100));
return new KnnSearchBuilder(
yield new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
newNumCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).boost(instance.boost);
case 4:
return new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).addFilterQueries(instance.filterQueries)
.addFilterQuery(QueryBuilders.termQuery("new_field", "new-value"))
.boost(instance.boost);
case 5:
}
case 4 -> new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).addFilterQueries(instance.filterQueries)
.addFilterQuery(QueryBuilders.termQuery("new_field", "new-value"))
.boost(instance.boost);
case 5 -> {
float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat);
return new KnnSearchBuilder(
yield new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
instance.similarity
).addFilterQueries(instance.filterQueries).boost(newBoost);
case 6:
return new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
}
case 6 -> new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
instance.getRescoreVectorBuilder(),
randomValueOtherThan(instance.similarity, ESTestCase::randomFloat)
).addFilterQueries(instance.filterQueries).boost(instance.boost);
case 7 -> new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
randomValueOtherThan(
instance.getRescoreVectorBuilder(),
randomValueOtherThan(instance.similarity, ESTestCase::randomFloat)
).addFilterQueries(instance.filterQueries).boost(instance.boost);
case 7:
return new KnnSearchBuilder(
instance.field,
instance.queryVector,
instance.k,
instance.numCands,
randomValueOtherThan(
instance.getRescoreVectorBuilder(),
() -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false))
),
instance.similarity
).addFilterQueries(instance.filterQueries).boost(instance.boost);
default:
throw new IllegalStateException();
}
() -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false))
),
instance.similarity
).addFilterQueries(instance.filterQueries).boost(instance.boost);
default -> throw new IllegalStateException();
};
}

public void testToQueryBuilder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static class TestQueryVectorBuilder implements QueryVectorBuilder {
PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR);
}

private List<Float> vectorToBuild;
private final List<Float> vectorToBuild;

public TestQueryVectorBuilder(List<Float> vectorToBuild) {
this.vectorToBuild = vectorToBuild;
Expand Down