diff --git a/server/src/main/java/org/elasticsearch/index/IndexModule.java b/server/src/main/java/org/elasticsearch/index/IndexModule.java index 42ab9ae362509..e4379f16ac902 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexModule.java +++ b/server/src/main/java/org/elasticsearch/index/IndexModule.java @@ -63,6 +63,7 @@ import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.threadpool.ThreadPool; @@ -496,7 +497,8 @@ public IndexService newIndexService( ValuesSourceRegistry valuesSourceRegistry, IndexStorePlugin.IndexFoldersDeletionListener indexFoldersDeletionListener, Map snapshotCommitSuppliers, - QueryRewriteInterceptor queryRewriteInterceptor + QueryRewriteInterceptor queryRewriteInterceptor, + SimpleQueryRewriter simpleQueryRewriter ) throws IOException { final IndexEventListener eventListener = freeze(); Function> readerWrapperFactory = indexReaderWrapper @@ -561,6 +563,7 @@ public IndexService newIndexService( indexCommitListener.get(), mapperMetrics, queryRewriteInterceptor, + simpleQueryRewriter, indexingStatsSettings, searchStatsSettings, mergeMetrics diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index b5180f70ae845..29d28cae7c2fe 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -91,6 +91,7 @@ import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.threadpool.ThreadPool; @@ -171,6 +172,7 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust private final ValuesSourceRegistry valuesSourceRegistry; private final MapperMetrics mapperMetrics; private final QueryRewriteInterceptor queryRewriteInterceptor; + private final SimpleQueryRewriter simpleQueryRewriter; private final IndexingStatsSettings indexingStatsSettings; private final SearchStatsSettings searchStatsSettings; private final MergeMetrics mergeMetrics; @@ -211,6 +213,7 @@ public IndexService( Engine.IndexCommitListener indexCommitListener, MapperMetrics mapperMetrics, QueryRewriteInterceptor queryRewriteInterceptor, + SimpleQueryRewriter simpleQueryRewriter, IndexingStatsSettings indexingStatsSettings, SearchStatsSettings searchStatsSettings, MergeMetrics mergeMetrics @@ -291,6 +294,7 @@ public IndexService( this.indexCommitListener = indexCommitListener; this.mapperMetrics = mapperMetrics; this.queryRewriteInterceptor = queryRewriteInterceptor; + this.simpleQueryRewriter = simpleQueryRewriter; try (var ignored = threadPool.getThreadContext().clearTraceContext()) { // kick off async ops for the first shard in this index this.refreshTask = new AsyncRefreshTask(this); diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index 8fdc53e6b795f..e80dda992064f 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -146,6 +146,7 @@ import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; @@ -281,7 +282,8 @@ public class IndicesService extends AbstractLifecycleComponent private final PostRecoveryMerger postRecoveryMerger; private final List searchOperationListeners; private final QueryRewriteInterceptor queryRewriteInterceptor; - final SlowLogFieldProvider slowLogFieldProvider; // pkg-private for testingå + private final SimpleQueryRewriter simpleQueryRewriter; + final SlowLogFieldProvider slowLogFieldProvider; // pkg-private for testing private final IndexingStatsSettings indexStatsSettings; private final SearchStatsSettings searchStatsSettings; private final MergeMetrics mergeMetrics; @@ -359,6 +361,7 @@ public void onRemoval(ShardId shardId, String fieldName, boolean wasEvicted, lon this.snapshotCommitSuppliers = builder.snapshotCommitSuppliers; this.requestCacheKeyDifferentiator = builder.requestCacheKeyDifferentiator; this.queryRewriteInterceptor = builder.queryRewriteInterceptor; + this.simpleQueryRewriter = builder.simpleQueryRewriter; this.mapperMetrics = builder.mapperMetrics; this.mergeMetrics = builder.mergeMetrics; // doClose() is called when shutting down a node, yet there might still be ongoing requests @@ -834,7 +837,8 @@ private synchronized IndexService createIndexService( valuesSourceRegistry, indexFoldersDeletionListeners, snapshotCommitSuppliers, - queryRewriteInterceptor + queryRewriteInterceptor, + simpleQueryRewriter ); } @@ -1865,6 +1869,10 @@ public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(Lo ); } + public SimpleQueryRewriter getSimpleQueryRewriter() { + return simpleQueryRewriter; + } + /** * Clears the caches for the given shard id if the shard is still allocated on this node */ diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesServiceBuilder.java b/server/src/main/java/org/elasticsearch/indices/IndicesServiceBuilder.java index 3b7f4d24869f2..be0605e26b675 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesServiceBuilder.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesServiceBuilder.java @@ -37,6 +37,7 @@ import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.internal.InternalSearchPlugin; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.internal.ShardSearchRequest; @@ -83,6 +84,7 @@ public class IndicesServiceBuilder { MergeMetrics mergeMetrics; List searchOperationListener = List.of(); QueryRewriteInterceptor queryRewriteInterceptor = null; + SimpleQueryRewriter simpleQueryRewriter = null; SlowLogFieldProvider slowLogFieldProvider = new SlowLogFieldProvider() { @Override public SlowLogFields create() { @@ -301,6 +303,27 @@ public IndicesService build() { })); queryRewriteInterceptor = QueryRewriteInterceptor.multi(queryRewriteInterceptors); + var simpleQueryRewriters = pluginsService.filterPlugins(InternalSearchPlugin.class) + .map(InternalSearchPlugin::getSimpleQueryRewriters) + .flatMap(List::stream) + .collect(Collectors.toMap(SimpleQueryRewriter::getName, interceptor -> { + if (interceptor.getName() == null) { + throw new IllegalArgumentException("SimpleQueryRewriter [" + interceptor.getClass().getName() + "] requires name"); + } + return interceptor; + }, (a, b) -> { + throw new IllegalStateException( + "Conflicting simple rewriters [" + + a.getName() + + "] found in [" + + a.getClass().getName() + + "] and [" + + b.getClass().getName() + + "]" + ); + })); + simpleQueryRewriter = SimpleQueryRewriter.multi(simpleQueryRewriters); + return new IndicesService(this); } } diff --git a/server/src/main/java/org/elasticsearch/plugins/internal/InternalSearchPlugin.java b/server/src/main/java/org/elasticsearch/plugins/internal/InternalSearchPlugin.java index 7ac18c4640a0b..a577cf9d99137 100644 --- a/server/src/main/java/org/elasticsearch/plugins/internal/InternalSearchPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/internal/InternalSearchPlugin.java @@ -10,6 +10,7 @@ package org.elasticsearch.plugins.internal; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import java.util.List; @@ -24,4 +25,8 @@ public interface InternalSearchPlugin { default List getQueryRewriteInterceptors() { return emptyList(); } + + default List getSimpleQueryRewriters() { + return emptyList(); + } } diff --git a/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/SimpleQueryRewriter.java b/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/SimpleQueryRewriter.java new file mode 100644 index 0000000000000..838088dc7b7c5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/plugins/internal/rewriter/SimpleQueryRewriter.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.plugins.internal.rewriter; + +import org.elasticsearch.index.query.QueryBuilder; + +import java.util.Map; + +public interface SimpleQueryRewriter { + + String getName(); + + QueryBuilder rewrite(QueryBuilder queryBuilder); + + static SimpleQueryRewriter multi(Map rewriters) { + return rewriters.isEmpty() ? new NoOpSimpleQueryRewriter() : new CompositeSimpleQueryRewriter(rewriters); + } + + class CompositeSimpleQueryRewriter implements SimpleQueryRewriter { + final String NAME = "composite"; + private final Map simpleQueryRewriters; + + private CompositeSimpleQueryRewriter(Map simpleQueryRewriters) { + this.simpleQueryRewriters = simpleQueryRewriters; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public QueryBuilder rewrite(QueryBuilder queryBuilder) { + SimpleQueryRewriter rewriter = simpleQueryRewriters.get(queryBuilder.getName()); + if (rewriter != null) { + return rewriter.rewrite(queryBuilder); + } + return queryBuilder; + } + } + + class NoOpSimpleQueryRewriter implements SimpleQueryRewriter { + @Override + public QueryBuilder rewrite(QueryBuilder queryBuilder) { + return queryBuilder; + } + + @Override + public String getName() { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index c2b526128a9bc..36469e489476e 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -51,6 +51,7 @@ import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.NestedHelper; import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.aggregations.SearchContextAggregations; import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -161,6 +162,7 @@ final class DefaultSearchContext extends SearchContext { private final Map searchExtBuilders = new HashMap<>(); private final SearchExecutionContext searchExecutionContext; private final FetchPhase fetchPhase; + private final SimpleQueryRewriter simpleQueryRewriter; DefaultSearchContext( ReaderContext readerContext, @@ -174,7 +176,8 @@ final class DefaultSearchContext extends SearchContext { SearchService.ResultsType resultsType, boolean enableQueryPhaseParallelCollection, int minimumDocsPerSlice, - long memoryAccountingBufferSize + long memoryAccountingBufferSize, + SimpleQueryRewriter simpleQueryRewriter ) throws IOException { this.readerContext = readerContext; this.request = request; @@ -186,6 +189,7 @@ final class DefaultSearchContext extends SearchContext { this.indexService = readerContext.indexService(); this.indexShard = readerContext.indexShard(); this.memoryAccountingBufferSize = memoryAccountingBufferSize; + this.simpleQueryRewriter = simpleQueryRewriter; Engine.Searcher engineSearcher = readerContext.acquireSearcher("search"); int maximumNumberOfSlices = determineMaximumNumberOfSlices( @@ -976,4 +980,9 @@ public IdLoader newIdLoader() { return IdLoader.fromLeafStoredFieldLoader(); } } + + @Override + public SimpleQueryRewriter simpleQueryRewriter() { + return simpleQueryRewriter; + } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 9c228468f1964..21a94390a91dc 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -1402,7 +1402,8 @@ private DefaultSearchContext createSearchContext( resultsType, enableQueryPhaseParallelCollection, minimumDocsPerSlice, - memoryAccountingBufferSize + memoryAccountingBufferSize, + indicesService.getSimpleQueryRewriter() ); // we clone the query shard context here just for rewriting otherwise we // might end up with incorrect state since we are using now() or script services diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 7d018a7ef4ba9..3b0c2893479e7 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -26,6 +26,7 @@ import org.elasticsearch.index.query.QueryShardException; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.RescoreDocIds; import org.elasticsearch.search.SearchExtBuilder; import org.elasticsearch.search.SearchShardTarget; @@ -449,4 +450,6 @@ public String toString() { public abstract SourceLoader newSourceLoader(@Nullable SourceFilter sourceFilter); public abstract IdLoader newIdLoader(); + + public abstract SimpleQueryRewriter simpleQueryRewriter(); } diff --git a/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java index f68539289015e..f9b12a4cef71d 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; import org.elasticsearch.index.query.ParsedQuery; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.aggregations.SearchContextAggregations; import org.elasticsearch.search.collapse.CollapseContext; import org.elasticsearch.search.fetch.FetchSearchResult; @@ -305,4 +306,9 @@ public TotalHits getTotalHits() { public float getMaxScore() { return querySearchResult.getMaxScore(); } + + @Override + public SimpleQueryRewriter simpleQueryRewriter() { + throw new UnsupportedOperationException("Not supported"); + } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java index c8c6dc942c148..1b2b70461edb1 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.query.ParsedQuery; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.SearchExtBuilder; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.SearchContextAggregations; @@ -546,4 +547,9 @@ public SourceLoader newSourceLoader(@Nullable SourceFilter filter) { public IdLoader newIdLoader() { throw new UnsupportedOperationException(); } + + @Override + public SimpleQueryRewriter simpleQueryRewriter() { + return parent.simpleQueryRewriter(); + } } diff --git a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java index db29c9a61e007..c96d8121aaeb6 100644 --- a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java +++ b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java @@ -239,7 +239,8 @@ private IndexService newIndexService(IndexModule module) throws IOException { null, indexDeletionListener, emptyMap(), - new MockQueryRewriteInterceptor() + new MockQueryRewriteInterceptor(), + null ); } diff --git a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java index 335815fcea445..141bfcf4d042d 100644 --- a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java @@ -190,7 +190,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ); contextWithoutScroll.from(300); contextWithoutScroll.close(); @@ -233,7 +234,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ) ) { @@ -317,7 +319,8 @@ public ScrollContext scrollContext() { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ) ) { @@ -360,7 +363,8 @@ public ScrollContext scrollContext() { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ) ) { context3.sliceBuilder(null).parsedQuery(parsedQuery).preProcess(); @@ -392,7 +396,8 @@ public ScrollContext scrollContext() { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ) ) { context4.sliceBuilder(new SliceBuilder(1, 2)).parsedQuery(parsedQuery).preProcess(); @@ -464,7 +469,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ); assertThat(context.searcher().hasCancellations(), is(false)); @@ -1084,7 +1090,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { randomFrom(SearchService.ResultsType.values()), randomBoolean(), randomInt(), - MEMORY_ACCOUNTING_BUFFER_SIZE + MEMORY_ACCOUNTING_BUFFER_SIZE, + null ); } } diff --git a/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java b/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java index 5a1c4b789b460..ae0249b260935 100644 --- a/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.index.query.ParsedQuery; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardTestCase; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.fetch.subphase.FetchDocValuesContext; import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; import org.elasticsearch.search.internal.ContextIndexSearcher; @@ -78,6 +79,11 @@ public boolean lowLevelCancellation() { return true; } + @Override + public SimpleQueryRewriter simpleQueryRewriter() { + return null; + } + @Override public FetchDocValuesContext docValuesContext() { return context.docValuesContext(); diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index 8ac52c3d48a0d..11889a8afd166 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.search.SearchExtBuilder; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.SearchContextAggregations; @@ -562,4 +563,9 @@ public SourceLoader newSourceLoader(@Nullable SourceFilter filter) { public IdLoader newIdLoader() { throw new UnsupportedOperationException(); } + + @Override + public SimpleQueryRewriter simpleQueryRewriter() { + return new SimpleQueryRewriter.NoOpSimpleQueryRewriter(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index de31f9d6cefc8..a311ac03e1aae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -43,6 +43,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.plugins.internal.InternalSearchPlugin; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHeaderDefinition; @@ -96,6 +97,7 @@ import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SimpleSemanticQueryRewriter; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; @@ -555,6 +557,11 @@ public List getQueryRewriteInterceptors() { ); } + @Override + public List getSimpleQueryRewriters() { + return List.of(new SimpleSemanticQueryRewriter()); + } + @Override public List> getRetrievers() { return List.of( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SimpleSemanticQueryRewriter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SimpleSemanticQueryRewriter.java new file mode 100644 index 0000000000000..35a7a87577100 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SimpleSemanticQueryRewriter.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.plugins.internal.rewriter.SimpleQueryRewriter; + +public class SimpleSemanticQueryRewriter implements SimpleQueryRewriter { + + public SimpleSemanticQueryRewriter() { + + } + + public String getName() { + return MatchQueryBuilder.NAME; + } + + @Override + public QueryBuilder rewrite(QueryBuilder queryBuilder) { + + if (queryBuilder instanceof MatchQueryBuilder == false) { + // no-op + return queryBuilder; + } + + MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; + return new SemanticQueryBuilder(matchQueryBuilder.fieldName(), matchQueryBuilder.value().toString()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SimpleSemanticQueryRewriterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SimpleSemanticQueryRewriterTests.java new file mode 100644 index 0000000000000..dbca55bb1de23 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SimpleSemanticQueryRewriterTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.index.query; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.elasticsearch.xpack.inference.queries.SimpleSemanticQueryRewriter; + +public class SimpleSemanticQueryRewriterTests extends ESTestCase { + + public void testMatchQueryRewrite() { + MatchQueryBuilder matchQuery = QueryBuilders.matchQuery("field", "value"); + SimpleSemanticQueryRewriter rewriter = new SimpleSemanticQueryRewriter(); + QueryBuilder expected = new SemanticQueryBuilder("field", "value"); + QueryBuilder rewritten = rewriter.rewrite(matchQuery); + assertEquals(rewritten, expected); + } + + public void testNoOpRewrite() { + TermQueryBuilder termQuery = QueryBuilders.termQuery("field", "value"); + SimpleSemanticQueryRewriter rewriter = new SimpleSemanticQueryRewriter(); + QueryBuilder rewritten = rewriter.rewrite(termQuery); + assertEquals(rewritten, termQuery); + } +}