
Commit 5c0b35d

Merge branch 'main' of github.com:elastic/elasticsearch into inference-eis-acl

2 parents: 1ca8b32 + 2e836cf

48 files changed: +1320, -459 lines changed


docs/reference/inference/update-inference.asciidoc

Lines changed: 4 additions & 4 deletions
@@ -19,9 +19,9 @@ However, if you do not plan to use the {infer} APIs to use these models or if yo
 [[update-inference-api-request]]
 ==== {api-request-title}
 
-`POST _inference/<inference_id>/_update`
+`PUT _inference/<inference_id>/_update`
 
-`POST _inference/<task_type>/<inference_id>/_update`
+`PUT _inference/<task_type>/<inference_id>/_update`
 
 
 [discrete]
@@ -52,7 +52,7 @@ Click the links to review the service configuration details:
 * <<infer-service-elasticsearch,Elasticsearch>> (`rerank`, `sparse_embedding`, `text_embedding` - this service is for built-in models and models uploaded through Eland)
 * <<infer-service-elser,ELSER>> (`sparse_embedding`)
 * <<infer-service-google-ai-studio,Google AI Studio>> (`completion`, `text_embedding`)
-* <<infer-service-google-vertex-ai,Google Vertex AI>> (`rerank`, `text_embedding`)
+* <<infer-service-google-vertex-ai,Google Vertex AI>> (`rerank`, `text_embedding`)
 * <<infer-service-hugging-face,Hugging Face>> (`text_embedding`)
 * <<infer-service-mistral,Mistral>> (`text_embedding`)
 * <<infer-service-openai,OpenAI>> (`completion`, `text_embedding`)
@@ -81,7 +81,7 @@ The following example shows how to update an API key of an {infer} endpoint call
 
 [source,console]
 ------------------------------------------------------------
-POST _inference/my-inference-endpoint/_update
+PUT _inference/my-inference-endpoint/_update
 {
   "service_settings": {
     "api_key": "<API_KEY>"

muted-tests.yml

Lines changed: 0 additions & 3 deletions
@@ -224,9 +224,6 @@ tests:
 - class: org.elasticsearch.search.profile.dfs.DfsProfilerIT
   method: testProfileDfs
   issue: https://github.com/elastic/elasticsearch/issues/119711
-- class: org.elasticsearch.xpack.inference.InferenceCrudIT
-  method: testGetServicesWithCompletionTaskType
-  issue: https://github.com/elastic/elasticsearch/issues/119959
 - class: org.elasticsearch.multi_cluster.MultiClusterYamlTestSuiteIT
   issue: https://github.com/elastic/elasticsearch/issues/119983
 - class: org.elasticsearch.xpack.test.rest.XPackRestIT

server/src/main/java/org/elasticsearch/action/admin/cluster/node/capabilities/NodeCapability.java

Lines changed: 5 additions & 0 deletions
@@ -41,4 +41,9 @@ public void writeTo(StreamOutput out) throws IOException {
 
         out.writeBoolean(supported);
     }
+
+    @Override
+    public String toString() {
+        return "NodeCapability{supported=" + supported + '}';
+    }
 }
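
The new override follows the usual pattern of giving small wire-format objects a readable toString for logs and test failure messages. A standalone sketch of the pattern, using a hypothetical Capability class rather than the real NodeCapability (which carries additional node state):

// Hypothetical stand-in for NodeCapability, for illustration only.
public class Capability {
    private final boolean supported;

    public Capability(boolean supported) {
        this.supported = supported;
    }

    @Override
    public String toString() {
        return "Capability{supported=" + supported + '}';
    }

    public static void main(String[] args) {
        // Prints "Capability{supported=true}" instead of a bare object hash.
        System.out.println(new Capability(true));
    }
}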

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 104 additions & 3 deletions
@@ -10,19 +10,30 @@
 package org.elasticsearch.action.search;
 
 import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermStatistics;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.dfs.DfsKnnResults;
 import org.elasticsearch.search.dfs.DfsSearchResult;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.transport.Transport;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Executor;
@@ -93,12 +104,11 @@ protected void executePhaseOnShard(
     @Override
     protected SearchPhase getNextPhase() {
         final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
-        final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults);
-        final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
+        final AggregatedDfs aggregatedDfs = aggregateDfs(dfsSearchResults);
         return new DfsQueryPhase(
             dfsSearchResults,
             aggregatedDfs,
-            mergedKnnResults,
+            mergeKnnResults(getRequest(), dfsSearchResults),
             queryPhaseResultConsumer,
             (queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, this, queryResults, aggregatedDfs),
             this
@@ -109,4 +119,95 @@ protected SearchPhase getNextPhase() {
     protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
         progressListener.notifyQueryFailure(shardIndex, shardTarget, exc);
     }
+
+    private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
+        if (request.hasKnnSearch() == false) {
+            return null;
+        }
+        SearchSourceBuilder source = request.source();
+        List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
+        List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            topDocsLists.add(new ArrayList<>());
+            nestedPath.add(new SetOnce<>());
+        }
+
+        for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
+            if (dfsSearchResult.knnResults() != null) {
+                for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
+                    DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
+                    ScoreDoc[] scoreDocs = knnResults.scoreDocs();
+                    TotalHits totalHits = new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO);
+                    TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
+                    SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
+                    topDocsLists.get(i).add(shardTopDocs);
+                    nestedPath.get(i).trySet(knnResults.getNestedPath());
+                }
+            }
+        }
+
+        List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
+        for (int i = 0; i < source.knnSearch().size(); i++) {
+            TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
+            mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
+        }
+        return mergedResults;
+    }
+
+    private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
+        Map<Term, TermStatistics> termStatistics = new HashMap<>();
+        Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
+        long aggMaxDoc = 0;
+        for (DfsSearchResult lEntry : results) {
+            final Term[] terms = lEntry.terms();
+            final TermStatistics[] stats = lEntry.termStatistics();
+            assert terms.length == stats.length;
+            for (int i = 0; i < terms.length; i++) {
+                assert terms[i] != null;
+                if (stats[i] == null) {
+                    continue;
+                }
+                TermStatistics existing = termStatistics.get(terms[i]);
+                if (existing != null) {
+                    assert terms[i].bytes().equals(existing.term());
+                    termStatistics.put(
+                        terms[i],
+                        new TermStatistics(
+                            existing.term(),
+                            existing.docFreq() + stats[i].docFreq(),
+                            existing.totalTermFreq() + stats[i].totalTermFreq()
+                        )
+                    );
+                } else {
+                    termStatistics.put(terms[i], stats[i]);
+                }
+
+            }
+
+            assert lEntry.fieldStatistics().containsKey(null) == false;
+            for (var entry : lEntry.fieldStatistics().entrySet()) {
+                String key = entry.getKey();
+                CollectionStatistics value = entry.getValue();
+                if (value == null) {
+                    continue;
+                }
+                assert key != null;
+                CollectionStatistics existing = fieldStatistics.get(key);
+                if (existing != null) {
+                    CollectionStatistics merged = new CollectionStatistics(
+                        key,
+                        existing.maxDoc() + value.maxDoc(),
+                        existing.docCount() + value.docCount(),
+                        existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
+                        existing.sumDocFreq() + value.sumDocFreq()
+                    );
+                    fieldStatistics.put(key, merged);
+                } else {
+                    fieldStatistics.put(key, value);
+                }
+            }
+            aggMaxDoc += lEntry.maxDoc();
+        }
+        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
+    }
 }
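
The relocated mergeKnnResults builds one TopDocs per shard and per kNN clause, then delegates the global top-k selection to Lucene's TopDocs.merge. A self-contained sketch of that merge step, with invented shard scores (the real code also tracks nested paths through SetOnce and tags shard indexes via SearchPhaseController.setShardIndex):

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public class KnnMergeSketch {
    public static void main(String[] args) {
        // Two shards, each contributing its local top hits for one kNN clause.
        TopDocs shard0 = shardTopDocs(0, new float[] { 0.9f, 0.4f });
        TopDocs shard1 = shardTopDocs(1, new float[] { 0.7f, 0.6f });

        // Merge to a global top-3, as mergeKnnResults does per kNN clause with k.
        TopDocs merged = TopDocs.merge(3, new TopDocs[] { shard0, shard1 });
        for (ScoreDoc doc : merged.scoreDocs) {
            System.out.println("shard=" + doc.shardIndex + " doc=" + doc.doc + " score=" + doc.score);
        }
    }

    private static TopDocs shardTopDocs(int shardIndex, float[] scores) {
        ScoreDoc[] docs = new ScoreDoc[scores.length];
        for (int i = 0; i < scores.length; i++) {
            docs[i] = new ScoreDoc(i, scores[i]);
            // Tag each hit with its shard, as setShardIndex does in the real code.
            docs[i].shardIndex = shardIndex;
        }
        return new TopDocs(new TotalHits(docs.length, TotalHits.Relation.EQUAL_TO), docs);
    }
}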

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 0 additions & 130 deletions
@@ -9,20 +9,16 @@
 
 package org.elasticsearch.action.search;
 
-import org.apache.lucene.index.Term;
-import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.SortedNumericSortField;
 import org.apache.lucene.search.SortedSetSortField;
-import org.apache.lucene.search.TermStatistics;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
-import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.lucene.Lucene;
@@ -42,9 +38,6 @@
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.dfs.AggregatedDfs;
-import org.elasticsearch.search.dfs.DfsKnnResults;
-import org.elasticsearch.search.dfs.DfsSearchResult;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.profile.SearchProfileQueryPhaseResult;
@@ -84,97 +77,6 @@ public SearchPhaseController(
         this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
     }
 
-    public static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
-        Map<Term, TermStatistics> termStatistics = new HashMap<>();
-        Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
-        long aggMaxDoc = 0;
-        for (DfsSearchResult lEntry : results) {
-            final Term[] terms = lEntry.terms();
-            final TermStatistics[] stats = lEntry.termStatistics();
-            assert terms.length == stats.length;
-            for (int i = 0; i < terms.length; i++) {
-                assert terms[i] != null;
-                if (stats[i] == null) {
-                    continue;
-                }
-                TermStatistics existing = termStatistics.get(terms[i]);
-                if (existing != null) {
-                    assert terms[i].bytes().equals(existing.term());
-                    termStatistics.put(
-                        terms[i],
-                        new TermStatistics(
-                            existing.term(),
-                            existing.docFreq() + stats[i].docFreq(),
-                            existing.totalTermFreq() + stats[i].totalTermFreq()
-                        )
-                    );
-                } else {
-                    termStatistics.put(terms[i], stats[i]);
-                }
-
-            }
-
-            assert lEntry.fieldStatistics().containsKey(null) == false;
-            for (var entry : lEntry.fieldStatistics().entrySet()) {
-                String key = entry.getKey();
-                CollectionStatistics value = entry.getValue();
-                if (value == null) {
-                    continue;
-                }
-                assert key != null;
-                CollectionStatistics existing = fieldStatistics.get(key);
-                if (existing != null) {
-                    CollectionStatistics merged = new CollectionStatistics(
-                        key,
-                        existing.maxDoc() + value.maxDoc(),
-                        existing.docCount() + value.docCount(),
-                        existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
-                        existing.sumDocFreq() + value.sumDocFreq()
-                    );
-                    fieldStatistics.put(key, merged);
-                } else {
-                    fieldStatistics.put(key, value);
-                }
-            }
-            aggMaxDoc += lEntry.maxDoc();
-        }
-        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
-    }
-
-    public static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
-        if (request.hasKnnSearch() == false) {
-            return null;
-        }
-        SearchSourceBuilder source = request.source();
-        List<List<TopDocs>> topDocsLists = new ArrayList<>(source.knnSearch().size());
-        List<SetOnce<String>> nestedPath = new ArrayList<>(source.knnSearch().size());
-        for (int i = 0; i < source.knnSearch().size(); i++) {
-            topDocsLists.add(new ArrayList<>());
-            nestedPath.add(new SetOnce<>());
-        }
-
-        for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
-            if (dfsSearchResult.knnResults() != null) {
-                for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
-                    DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
-                    ScoreDoc[] scoreDocs = knnResults.scoreDocs();
-                    TotalHits totalHits = new TotalHits(scoreDocs.length, Relation.EQUAL_TO);
-                    TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
-                    setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
-                    topDocsLists.get(i).add(shardTopDocs);
-                    nestedPath.get(i).trySet(knnResults.getNestedPath());
-                }
-            }
-        }
-
-        List<DfsKnnResults> mergedResults = new ArrayList<>(source.knnSearch().size());
-        for (int i = 0; i < source.knnSearch().size(); i++) {
-            TopDocs mergedTopDocs = TopDocs.merge(source.knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
-            mergedResults.add(new DfsKnnResults(nestedPath.get(i).get(), mergedTopDocs.scoreDocs));
-        }
-        return mergedResults;
-    }
-
     /**
      * Returns a score doc array of top N search docs across all shards, followed by top suggest docs for each
      * named completion suggestion across all shards. If more than one named completion suggestion is specified in the
@@ -496,38 +398,6 @@ private static SearchHits getHits(
         );
     }
 
-    /**
-     * Reduces the given query results and consumes all aggregations and profile results.
-     * @param queryResults a list of non-null query shard results
-     */
-    static ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
-        AggregationReduceContext.Builder aggReduceContextBuilder = new AggregationReduceContext.Builder() {
-            @Override
-            public AggregationReduceContext forPartialReduction() {
-                throw new UnsupportedOperationException("Scroll requests don't have aggs");
-            }
-
-            @Override
-            public AggregationReduceContext forFinalReduction() {
-                throw new UnsupportedOperationException("Scroll requests don't have aggs");
-            }
-        };
-        final TopDocsStats topDocsStats = new TopDocsStats(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
-        final List<TopDocs> topDocs = new ArrayList<>();
-        for (SearchPhaseResult sortedResult : queryResults) {
-            QuerySearchResult queryResult = sortedResult.queryResult();
-            final TopDocsAndMaxScore td = queryResult.consumeTopDocs();
-            assert td != null;
-            topDocsStats.add(td, queryResult.searchTimedOut(), queryResult.terminatedEarly());
-            // make sure we set the shard index before we add it - the consumer didn't do that yet
-            if (td.topDocs.scoreDocs.length > 0) {
-                setShardIndex(td.topDocs, queryResult.getShardIndex());
-                topDocs.add(td.topDocs);
-            }
-        }
-        return reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, aggReduceContextBuilder, null, true);
-    }
-
     /**
      * Reduces the given query results and consumes all aggregations and profile results.
      * @param queryResults a list of non-null query shard results