Skip to content

Commit a01e4bc

Browse files
committed
Add SemanticQueryBuilder parsing and tests
1 parent 7c4bcf4 commit a01e4bc

File tree

3 files changed

+68
-18
lines changed

3 files changed

+68
-18
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,8 +2167,8 @@ public void checkPushDownComplexNegationsToPrefilter(
21672167
String query = String.format(Locale.ROOT, """
21682168
from test
21692169
| where ((%s
2170-
or NOT integer > 10) and NOT ((keyword == "test")
2171-
or %s))
2170+
or NOT integer > 10) and NOT ((keyword == "test")
2171+
or %s))
21722172
""", esqlQuery1, esqlQuery2);
21732173
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
21742174

@@ -2182,20 +2182,20 @@ public void checkPushDownComplexNegationsToPrefilter(
21822182
query,
21832183
unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
21842184
"keyword",
2185-
new Source(3, 32, "keyword == \"test\"")
2185+
new Source(3, 35, "keyword == \"test\"")
21862186
);
21872187
QueryBuilder notKeywordFilter = wrapWithSingleQuery(
21882188
query,
21892189
unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
21902190
"keyword",
2191-
new Source(3, 26, "NOT ((keyword == \"test\")\n or " + esqlQuery1 + ")")
2191+
new Source(3, 29, "NOT ((keyword == \"test\")\n or " + esqlQuery2 + ")")
21922192
);
21932193

21942194
QueryBuilder notIntegerGt10 = wrapWithSingleQuery(
21952195
query,
21962196
unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))),
21972197
"integer",
2198-
new Source(3, 4, "NOT integer > 10")
2198+
new Source(3, 7, "NOT integer > 10")
21992199
);
22002200

22012201
expectedQueryBuilder1.addFilterQueries(List.of(notKeywordFilter));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.inference.TaskType;
3030
import org.elasticsearch.search.vectors.FilteredQueryBuilder;
3131
import org.elasticsearch.xcontent.ConstructingObjectParser;
32+
import org.elasticsearch.xcontent.ObjectParser;
3233
import org.elasticsearch.xcontent.ParseField;
3334
import org.elasticsearch.xcontent.XContentBuilder;
3435
import org.elasticsearch.xcontent.XContentParser;
@@ -77,6 +78,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
7778
private static final ParseField FIELD_FIELD = new ParseField("field");
7879
private static final ParseField QUERY_FIELD = new ParseField("query");
7980
private static final ParseField LENIENT_FIELD = new ParseField("lenient");
81+
private static final ParseField FILTER_FIELD = new ParseField("filter");
8082

8183
private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
8284
NAME,
@@ -92,6 +94,12 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
9294
PARSER.declareString(constructorArg(), QUERY_FIELD);
9395
PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD);
9496
declareStandardFields(PARSER);
97+
PARSER.declareFieldArray(
98+
SemanticQueryBuilder::addFilterQueries,
99+
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
100+
FILTER_FIELD,
101+
ObjectParser.ValueType.OBJECT_ARRAY
102+
);
95103
}
96104

97105
private final String fieldName;
@@ -359,6 +367,13 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
359367
if (lenient != null) {
360368
builder.field(LENIENT_FIELD.getPreferredName(), lenient);
361369
}
370+
if (filterQueries.isEmpty() == false) {
371+
builder.startArray(FILTER_FIELD.getPreferredName());
372+
for (QueryBuilder filterQuery : filterQueries) {
373+
filterQuery.toXContent(builder, params);
374+
}
375+
builder.endArray();
376+
}
362377
boostAndQueryNameToXContent(builder);
363378
builder.endObject();
364379
}
@@ -432,6 +447,10 @@ public SemanticQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries) {
432447
return this;
433448
}
434449

450+
public List<QueryBuilder> getFilterQueries() {
451+
return Collections.unmodifiableList(filterQueries);
452+
}
453+
435454
private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
436455
ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
437456
if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
@@ -452,7 +471,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu
452471
// The inference results map is fully populated, so we can perform error checking
453472
inferenceResultsErrorCheck(modifiedInferenceResultsMap);
454473
} else {
455-
rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap).addFilterQueries(filterQueries);
474+
rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap);
456475
}
457476
}
458477

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
4747
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
4848
import org.elasticsearch.index.query.QueryBuilder;
49+
import org.elasticsearch.index.query.QueryBuilders;
4950
import org.elasticsearch.index.query.QueryRewriteContext;
5051
import org.elasticsearch.index.query.SearchExecutionContext;
5152
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
@@ -241,6 +242,15 @@ protected SemanticQueryBuilder doCreateTestQueryBuilder() {
241242
if (randomBoolean()) {
242243
builder.queryName(randomAlphaOfLength(4));
243244
}
245+
if (randomBoolean()) {
246+
List<QueryBuilder> filters = new ArrayList<>();
247+
int numFilters = randomIntBetween(1, 5);
248+
for (int i = 0; i < numFilters; i++) {
249+
String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
250+
filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)));
251+
}
252+
builder.addFilterQueries(filters);
253+
}
244254

245255
return builder;
246256
}
@@ -256,13 +266,13 @@ protected void doAssertLuceneQuery(SemanticQueryBuilder queryBuilder, Query quer
256266

257267
switch (inferenceResultType) {
258268
case NONE -> assertThat(nestedQuery.getChildQuery(), instanceOf(MatchNoDocsQuery.class));
259-
case SPARSE_EMBEDDING -> assertSparseEmbeddingLuceneQuery(nestedQuery.getChildQuery());
260-
case TEXT_EMBEDDING -> assertTextEmbeddingLuceneQuery(nestedQuery.getChildQuery());
269+
case SPARSE_EMBEDDING -> assertSparseEmbeddingLuceneQuery(queryBuilder, nestedQuery.getChildQuery());
270+
case TEXT_EMBEDDING -> assertTextEmbeddingLuceneQuery(queryBuilder, nestedQuery.getChildQuery());
261271
}
262272
}
263273

264-
private void assertSparseEmbeddingLuceneQuery(Query query) {
265-
Query innerQuery = assertOuterBooleanQuery(query);
274+
private void assertSparseEmbeddingLuceneQuery(SemanticQueryBuilder queryBuilder, Query query) {
275+
Query innerQuery = assertOuterBooleanQuery(query, queryBuilder);
266276
assertThat(innerQuery, instanceOf(SparseVectorQueryWrapper.class));
267277
var sparseQuery = (SparseVectorQueryWrapper) innerQuery;
268278
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
@@ -272,8 +282,8 @@ private void assertSparseEmbeddingLuceneQuery(Query query) {
272282
assertThat(innerBooleanQuery.clauses().size(), equalTo(0));
273283
}
274284

275-
private void assertTextEmbeddingLuceneQuery(Query query) {
276-
Query innerQuery = assertOuterBooleanQuery(query);
285+
private void assertTextEmbeddingLuceneQuery(SemanticQueryBuilder queryBuilder, Query query) {
286+
Query innerQuery = assertOuterBooleanQuery(query, queryBuilder);
277287

278288
Class<? extends Query> expectedKnnQueryClass = switch (denseVectorElementType) {
279289
case FLOAT -> KnnFloatVectorQuery.class;
@@ -282,12 +292,38 @@ private void assertTextEmbeddingLuceneQuery(Query query) {
282292
assertThat(innerQuery, instanceOf(expectedKnnQueryClass));
283293
}
284294

285-
private Query assertOuterBooleanQuery(Query query) {
295+
private Query assertOuterBooleanQuery(Query query, SemanticQueryBuilder queryBuilder) {
286296
assertThat(query, instanceOf(BooleanQuery.class));
287297
BooleanQuery outerBooleanQuery = (BooleanQuery) query;
288298

289299
List<BooleanClause> outerMustClauses = new ArrayList<>();
290300
List<BooleanClause> outerFilterClauses = new ArrayList<>();
301+
retrieveMustAndFilterClauses(outerBooleanQuery, outerMustClauses, outerFilterClauses);
302+
303+
assertThat(outerMustClauses.size(), equalTo(1));
304+
305+
int expectedFilterClauses = 1;
306+
if (inferenceResultType == InferenceResultType.SPARSE_EMBEDDING) {
307+
// Outer must clause contains query builder filters and the must clause
308+
outerBooleanQuery = (BooleanQuery) outerMustClauses.get(0).query();
309+
outerMustClauses.clear();
310+
outerFilterClauses.clear();
311+
312+
retrieveMustAndFilterClauses(outerBooleanQuery, outerMustClauses, outerFilterClauses);
313+
314+
assertThat(outerMustClauses.size(), equalTo(1));
315+
expectedFilterClauses = queryBuilder.getFilterQueries().size();
316+
}
317+
assertThat(outerFilterClauses.size(), equalTo(expectedFilterClauses));
318+
319+
return outerMustClauses.get(0).query();
320+
}
321+
322+
private static void retrieveMustAndFilterClauses(
323+
BooleanQuery outerBooleanQuery,
324+
List<BooleanClause> outerMustClauses,
325+
List<BooleanClause> outerFilterClauses
326+
) {
291327
for (BooleanClause clause : outerBooleanQuery.clauses()) {
292328
BooleanClause.Occur occur = clause.occur();
293329
if (occur == MUST) {
@@ -298,11 +334,6 @@ private Query assertOuterBooleanQuery(Query query) {
298334
fail("Unexpected boolean " + occur + " clause");
299335
}
300336
}
301-
302-
assertThat(outerMustClauses.size(), equalTo(1));
303-
assertThat(outerFilterClauses.size(), equalTo(1));
304-
305-
return outerMustClauses.get(0).query();
306337
}
307338

308339
@Override

0 commit comments

Comments
 (0)