Skip to content

Commit f5080a6

Browse files
committed
Small name refactoring, fix adjusting parameters
1 parent a7936da commit f5080a6

File tree

4 files changed

+44
-24
lines changed

4 files changed

+44
-24
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
7171
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
7272
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
73-
import org.elasticsearch.search.vectors.KnnRescoreVectorQuery;
73+
import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
7474
import org.elasticsearch.search.vectors.VectorData;
7575
import org.elasticsearch.search.vectors.VectorSimilarityQuery;
7676
import org.elasticsearch.xcontent.ToXContent;
@@ -2096,10 +2096,12 @@ private Query createKnnByteQuery(
20962096
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
20972097
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
20982098
}
2099-
Integer adjustedK = k == null || needsRescore(rescoreOversample) == false
2100-
? null
2101-
: Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2102-
int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands);
2099+
Integer adjustedK = k;
2100+
int adjustedNumCands = numCands;
2101+
if (needsRescore(rescoreOversample) && adjustedK != null) {
2102+
adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2103+
adjustedNumCands = Math.max(adjustedK, numCands);
2104+
}
21032105

21042106
Query knnQuery = parentFilter != null
21052107
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
@@ -2112,7 +2114,7 @@ private Query createKnnByteQuery(
21122114
);
21132115
}
21142116
if (needsRescore(rescoreOversample)) {
2115-
knnQuery = new KnnRescoreVectorQuery(
2117+
knnQuery = new RescoreKnnVectorQuery(
21162118
name(),
21172119
queryVector,
21182120
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE),
@@ -2148,10 +2150,12 @@ && isNotUnitVector(squaredMagnitude)) {
21482150
}
21492151
}
21502152

2151-
Integer adjustedK = k == null || needsRescore(rescoreOversample) == false
2152-
? k
2153-
: Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)));
2154-
int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands);
2153+
Integer adjustedK = k;
2154+
int adjustedNumCands = numCands;
2155+
if (needsRescore(rescoreOversample) && adjustedK != null) {
2156+
adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
2157+
adjustedNumCands = Math.max(adjustedK, numCands);
2158+
}
21552159
Query knnQuery = parentFilter != null
21562160
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
21572161
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
@@ -2163,7 +2167,7 @@ && isNotUnitVector(squaredMagnitude)) {
21632167
);
21642168
}
21652169
if (needsRescore(rescoreOversample)) {
2166-
knnQuery = new KnnRescoreVectorQuery(
2170+
knnQuery = new RescoreKnnVectorQuery(
21672171
name(),
21682172
queryVector,
21692173
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),

server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
3535
public void profile(QueryProfiler queryProfiler) {
3636
queryProfiler.setVectorOpsCount(vectorOpsCount);
3737
}
38+
39+
public Integer kParam() {
40+
return kParam;
41+
}
3842
}

server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
3535
public void profile(QueryProfiler queryProfiler) {
3636
queryProfiler.setVectorOpsCount(vectorOpsCount);
3737
}
38+
39+
public Integer kParam() {
40+
return kParam;
41+
}
3842
}

server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java renamed to server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,44 +28,44 @@
2828
/**
2929
* Wraps a kNN vector query to rescore the results using the non-quantized vectors
3030
*/
31-
public class KnnRescoreVectorQuery extends Query implements ProfilingQuery {
31+
public class RescoreKnnVectorQuery extends Query implements ProfilingQuery {
3232
private final String fieldName;
3333
private final byte[] byteTarget;
3434
private final float[] floatTarget;
3535
private final VectorSimilarityFunction vectorSimilarityFunction;
3636
private final Integer k;
37-
private final Query vectorQuery;
37+
private final Query innerQuery;
3838

3939
private long vectorOpsCount;
4040

41-
public KnnRescoreVectorQuery(
41+
public RescoreKnnVectorQuery(
4242
String fieldName,
4343
byte[] byteTarget,
4444
VectorSimilarityFunction vectorSimilarityFunction,
4545
Integer k,
46-
Query vectorQuery
46+
Query innerQuery
4747
) {
4848
this.fieldName = fieldName;
4949
this.byteTarget = byteTarget;
5050
this.floatTarget = null;
5151
this.vectorSimilarityFunction = vectorSimilarityFunction;
5252
this.k = k;
53-
this.vectorQuery = vectorQuery;
53+
this.innerQuery = innerQuery;
5454
}
5555

56-
public KnnRescoreVectorQuery(
56+
public RescoreKnnVectorQuery(
5757
String fieldName,
5858
float[] floatTarget,
5959
VectorSimilarityFunction vectorSimilarityFunction,
6060
Integer k,
61-
Query vectorQuery
61+
Query innerQuery
6262
) {
6363
this.fieldName = fieldName;
6464
this.byteTarget = null;
6565
this.floatTarget = floatTarget;
6666
this.vectorSimilarityFunction = vectorSimilarityFunction;
6767
this.k = k;
68-
this.vectorQuery = vectorQuery;
68+
this.innerQuery = innerQuery;
6969
}
7070

7171
@Override
@@ -81,7 +81,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
8181
} else {
8282
valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
8383
}
84-
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource);
84+
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
8585
Query query = searcher.rewrite(functionScoreQuery);
8686

8787
if (k == null) {
@@ -103,6 +103,14 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
103103
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
104104
}
105105

106+
public Query innerQuery() {
107+
return innerQuery;
108+
}
109+
110+
public Integer k() {
111+
return k;
112+
}
113+
106114
@Override
107115
public void profile(QueryProfiler queryProfiler) {
108116
queryProfiler.setVectorOpsCount(vectorOpsCount);
@@ -119,18 +127,18 @@ public void visit(QueryVisitor visitor) {
119127
public boolean equals(Object o) {
120128
if (this == o) return true;
121129
if (o == null || getClass() != o.getClass()) return false;
122-
KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o;
130+
RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o;
123131
return Objects.equals(fieldName, that.fieldName)
124132
&& Objects.deepEquals(byteTarget, that.byteTarget)
125133
&& Objects.deepEquals(floatTarget, that.floatTarget)
126134
&& vectorSimilarityFunction == that.vectorSimilarityFunction
127135
&& Objects.equals(k, that.k)
128-
&& Objects.equals(vectorQuery, that.vectorQuery);
136+
&& Objects.equals(innerQuery, that.innerQuery);
129137
}
130138

131139
@Override
132140
public int hashCode() {
133-
return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery);
141+
return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery);
134142
}
135143

136144
@Override
@@ -144,7 +152,7 @@ public String toString(String field) {
144152
}
145153
sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction);
146154
sb.append(", k=").append(k);
147-
sb.append(", vectorQuery=").append(vectorQuery);
155+
sb.append(", vectorQuery=").append(innerQuery);
148156
sb.append('}');
149157
return sb.toString();
150158
}

0 commit comments

Comments
 (0)