Skip to content

Commit ffa423d

Browse files
markjhoyelasticsearchmachineelasticmachine
authored
Move SparseVector toDoQuery Finalization into Sparse Vector Field Mapping FieldType (#129020)
* move SparseVector toDoQuery final into field type * [CI] Auto commit changes from spotless * change method name to be specific to sparse vector --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Elastic Machine <[email protected]>
1 parent 0a5dfc0 commit ffa423d

File tree

39 files changed

+96
-59
lines changed

39 files changed

+96
-59
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.elasticsearch.index.mapper.TextSearchInfo;
3737
import org.elasticsearch.index.mapper.ValueFetcher;
3838
import org.elasticsearch.index.query.SearchExecutionContext;
39+
import org.elasticsearch.inference.WeightedToken;
40+
import org.elasticsearch.inference.WeightedTokensUtils;
3941
import org.elasticsearch.search.fetch.StoredFieldsSpec;
4042
import org.elasticsearch.search.lookup.Source;
4143
import org.elasticsearch.xcontent.XContentBuilder;
@@ -149,6 +151,18 @@ public Query existsQuery(SearchExecutionContext context) {
149151
return super.existsQuery(context);
150152
}
151153

154+
public Query finalizeSparseVectorQuery(
155+
SearchExecutionContext context,
156+
String fieldName,
157+
List<WeightedToken> queryVectors,
158+
boolean shouldPruneTokens,
159+
TokenPruningConfig tokenPruningConfig
160+
) throws IOException {
161+
return (shouldPruneTokens)
162+
? WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, this, context)
163+
: WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, this, context);
164+
}
165+
152166
private static String indexedValueForSearch(Object value) {
153167
if (value instanceof BytesRef) {
154168
return ((BytesRef) value).utf8ToString();
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.core.ml.search;
10+
package org.elasticsearch.index.mapper.vectors;
911

1012
import org.elasticsearch.common.ParsingException;
1113
import org.elasticsearch.common.io.stream.StreamInput;
@@ -22,9 +24,8 @@
2224
import java.util.Objects;
2325
import java.util.Set;
2426

25-
import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.PRUNING_CONFIG;
26-
2727
public class TokenPruningConfig implements Writeable, ToXContentObject {
28+
public static final String PRUNING_CONFIG_FIELD = "pruning_config";
2829
public static final ParseField TOKENS_FREQ_RATIO_THRESHOLD = new ParseField("tokens_freq_ratio_threshold");
2930
public static final ParseField TOKENS_WEIGHT_THRESHOLD = new ParseField("tokens_weight_threshold");
3031
public static final ParseField ONLY_SCORE_PRUNED_TOKENS_FIELD = new ParseField("only_score_pruned_tokens");
@@ -150,7 +151,7 @@ public static TokenPruningConfig fromXContent(XContentParser parser) throws IOEx
150151
).contains(currentFieldName) == false) {
151152
throw new ParsingException(
152153
parser.getTokenLocation(),
153-
"[" + PRUNING_CONFIG.getPreferredName() + "] unknown token [" + currentFieldName + "]"
154+
"[" + PRUNING_CONFIG_FIELD + "] unknown token [" + currentFieldName + "]"
154155
);
155156
}
156157
} else if (token.isValue()) {
@@ -163,13 +164,13 @@ public static TokenPruningConfig fromXContent(XContentParser parser) throws IOEx
163164
} else {
164165
throw new ParsingException(
165166
parser.getTokenLocation(),
166-
"[" + PRUNING_CONFIG.getPreferredName() + "] does not support [" + currentFieldName + "]"
167+
"[" + PRUNING_CONFIG_FIELD + "] does not support [" + currentFieldName + "]"
167168
);
168169
}
169170
} else {
170171
throw new ParsingException(
171172
parser.getTokenLocation(),
172-
"[" + PRUNING_CONFIG.getPreferredName() + "] unknown token [" + token + "] after [" + currentFieldName + "]"
173+
"[" + PRUNING_CONFIG_FIELD + "] unknown token [" + token + "] after [" + currentFieldName + "]"
173174
);
174175
}
175176
}
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.core.ml.search;
10+
package org.elasticsearch.inference;
911

1012
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.io.stream.StreamInput;
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.core.ml.search;
10+
package org.elasticsearch.inference;
911

1012
import org.apache.lucene.index.IndexReader;
1113
import org.apache.lucene.index.Term;
@@ -15,7 +17,9 @@
1517
import org.apache.lucene.search.MatchNoDocsQuery;
1618
import org.apache.lucene.search.Query;
1719
import org.elasticsearch.index.mapper.MappedFieldType;
20+
import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
1821
import org.elasticsearch.index.query.SearchExecutionContext;
22+
import org.elasticsearch.search.vectors.SparseVectorQueryWrapper;
1923

2024
import java.io.IOException;
2125
import java.util.List;
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
/*
22
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3-
* or more contributor license agreements. Licensed under the Elastic License
4-
* 2.0; you may not use this file except in compliance with the Elastic License
5-
* 2.0.
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
68
*/
79

8-
package org.elasticsearch.xpack.core.ml.search;
10+
package org.elasticsearch.search.vectors;
911

1012
import org.apache.lucene.search.BooleanClause;
1113
import org.apache.lucene.search.IndexSearcher;
1214
import org.apache.lucene.search.Query;
1315
import org.apache.lucene.search.QueryVisitor;
1416
import org.apache.lucene.search.ScoreMode;
1517
import org.apache.lucene.search.Weight;
16-
import org.elasticsearch.index.query.SearchExecutionContext;
1718

1819
import java.io.IOException;
1920
import java.util.Objects;
2021

2122
/**
22-
* A wrapper class for the Lucene query generated by {@link SparseVectorQueryBuilder#toQuery(SearchExecutionContext)}.
23+
* A wrapper class for the Lucene query generated by SparseVectorQueryBuilder#toQuery(SearchExecutionContext)
24+
* (found in x-pack/core/ml/search).
2325
* This wrapper facilitates the extraction of the complete sparse vector query using a {@link QueryVisitor}.
2426
*/
2527
public class SparseVectorQueryWrapper extends Query {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
1717
import org.elasticsearch.inference.InferenceResults;
1818
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.inference.WeightedToken;
1920
import org.elasticsearch.rest.RestStatus;
2021
import org.elasticsearch.xcontent.ToXContent;
2122
import org.elasticsearch.xcontent.ToXContentObject;
2223
import org.elasticsearch.xcontent.XContent;
2324
import org.elasticsearch.xcontent.XContentBuilder;
2425
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
25-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
2626

2727
import java.io.IOException;
2828
import java.util.ArrayList;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/MlChunkedTextExpansionResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
1212
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.WeightedToken;
1314
import org.elasticsearch.xcontent.ToXContentObject;
1415
import org.elasticsearch.xcontent.XContentBuilder;
15-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
1616

1717
import java.io.IOException;
1818
import java.util.HashMap;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.inference.WeightedToken;
1213
import org.elasticsearch.xcontent.XContentBuilder;
13-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
1414

1515
import java.io.IOException;
1616
import java.util.List;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@
1818
import org.elasticsearch.common.io.stream.StreamOutput;
1919
import org.elasticsearch.core.Nullable;
2020
import org.elasticsearch.index.mapper.MappedFieldType;
21+
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
22+
import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
2123
import org.elasticsearch.index.query.AbstractQueryBuilder;
2224
import org.elasticsearch.index.query.QueryBuilder;
2325
import org.elasticsearch.index.query.QueryRewriteContext;
2426
import org.elasticsearch.index.query.SearchExecutionContext;
2527
import org.elasticsearch.inference.InferenceResults;
28+
import org.elasticsearch.inference.WeightedToken;
2629
import org.elasticsearch.xcontent.ConstructingObjectParser;
2730
import org.elasticsearch.xcontent.ParseField;
2831
import org.elasticsearch.xcontent.XContentBuilder;
@@ -215,16 +218,13 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
215218
return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist");
216219
}
217220

218-
final String fieldTypeName = ft.typeName();
219-
if (fieldTypeName.equals(ALLOWED_FIELD_TYPE) == false) {
220-
throw new IllegalArgumentException(
221-
"field [" + fieldName + "] must be type [" + ALLOWED_FIELD_TYPE + "] but is type [" + fieldTypeName + "]"
222-
);
221+
if (ft instanceof SparseVectorFieldMapper.SparseVectorFieldType svft) {
222+
return svft.finalizeSparseVectorQuery(context, fieldName, queryVectors, shouldPruneTokens, tokenPruningConfig);
223223
}
224224

225-
return (shouldPruneTokens)
226-
? WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, ft, context)
227-
: WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, ft, context);
225+
throw new IllegalArgumentException(
226+
"field [" + fieldName + "] must be type [" + ALLOWED_FIELD_TYPE + "] but is type [" + ft.typeName() + "]"
227+
);
228228
}
229229

230230
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.common.logging.DeprecationCategory;
1919
import org.elasticsearch.common.logging.DeprecationLogger;
2020
import org.elasticsearch.core.Nullable;
21+
import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
2122
import org.elasticsearch.index.query.AbstractQueryBuilder;
2223
import org.elasticsearch.index.query.QueryBuilder;
2324
import org.elasticsearch.index.query.QueryBuilders;

0 commit comments

Comments
 (0)