Skip to content

Commit 91f7ea8

Browse files
committed
Merge remote-tracking branch 'upstream/main' into use-lucene-postings-format
2 parents 2637976 + e77bf80 commit 91f7ea8

File tree

19 files changed

+548
-180
lines changed

19 files changed

+548
-180
lines changed

docs/changelog/125694.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125694
2+
summary: LTR score bounding
3+
area: Ranking
4+
type: bug
5+
issues: []

docs/changelog/126002.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126002
2+
summary: Run `TransportGetLifecycleAction` on local node
3+
area: ILM+SLM
4+
type: enhancement
5+
issues: []

muted-tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ tests:
395395
- class: org.elasticsearch.xpack.security.authz.RBACEngineTests
396396
method: testBuildUserPrivilegeResponseCombinesIndexPrivileges
397397
issue: https://github.com/elastic/elasticsearch/issues/126130
398+
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeIT
399+
method: test
400+
issue: https://github.com/elastic/elasticsearch/issues/126139
398401

399402
# Examples:
400403
#

server/src/main/java/org/elasticsearch/action/support/local/LocalClusterStateRequest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionRequest;
14+
import org.elasticsearch.action.ActionRequestValidationException;
1415
import org.elasticsearch.action.support.TransportAction;
1516
import org.elasticsearch.common.io.stream.StreamInput;
1617
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -65,6 +66,11 @@ public final void writeTo(StreamOutput out) throws IOException {
6566
TransportAction.localOnly();
6667
}
6768

69+
@Override
70+
public ActionRequestValidationException validate() {
71+
return null;
72+
}
73+
6874
public TimeValue masterTimeout() {
6975
return masterTimeout;
7076
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/action/GetLifecycleAction.java

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import org.elasticsearch.action.ActionResponse;
1111
import org.elasticsearch.action.ActionType;
12-
import org.elasticsearch.action.support.master.AcknowledgedRequest;
12+
import org.elasticsearch.action.support.local.LocalClusterStateRequest;
1313
import org.elasticsearch.cluster.metadata.ItemUsage;
1414
import org.elasticsearch.common.Strings;
1515
import org.elasticsearch.common.collect.Iterators;
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.common.io.stream.Writeable;
1919
import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
2020
import org.elasticsearch.core.TimeValue;
21+
import org.elasticsearch.core.UpdateForV10;
2122
import org.elasticsearch.tasks.CancellableTask;
2223
import org.elasticsearch.tasks.Task;
2324
import org.elasticsearch.tasks.TaskId;
@@ -43,10 +44,6 @@ public static class Response extends ActionResponse implements ChunkedToXContent
4344

4445
private final List<LifecyclePolicyResponseItem> policies;
4546

46-
public Response(StreamInput in) throws IOException {
47-
this.policies = in.readCollectionAsList(LifecyclePolicyResponseItem::new);
48-
}
49-
5047
public Response(List<LifecyclePolicyResponseItem> policies) {
5148
this.policies = policies;
5249
}
@@ -55,6 +52,11 @@ public List<LifecyclePolicyResponseItem> getPolicies() {
5552
return policies;
5653
}
5754

55+
/**
56+
* NB prior to 9.1 this was a TransportMasterNodeAction so for BwC we must remain able to write these responses until
57+
* we no longer need to support calling this action remotely.
58+
*/
59+
@UpdateForV10(owner = UpdateForV10.Owner.DATA_MANAGEMENT)
5860
@Override
5961
public void writeTo(StreamOutput out) throws IOException {
6062
out.writeCollection(policies);
@@ -100,19 +102,26 @@ public Iterator<ToXContent> toXContentChunked(ToXContent.Params outerParams) {
100102
}
101103
}
102104

103-
public static class Request extends AcknowledgedRequest<Request> {
105+
public static class Request extends LocalClusterStateRequest {
104106
private final String[] policyNames;
105107

106-
public Request(TimeValue masterNodeTimeout, TimeValue ackTimeout, String... policyNames) {
107-
super(masterNodeTimeout, ackTimeout);
108+
public Request(TimeValue masterNodeTimeout, String... policyNames) {
109+
super(masterNodeTimeout);
108110
if (policyNames == null) {
109111
throw new IllegalArgumentException("ids cannot be null");
110112
}
111113
this.policyNames = policyNames;
112114
}
113115

116+
/**
117+
* NB prior to 9.1 this was a TransportMasterNodeAction so for BwC we must remain able to read these requests until
118+
* we no longer need to support calling this action remotely.
119+
*/
120+
@UpdateForV10(owner = UpdateForV10.Owner.DATA_MANAGEMENT)
114121
public Request(StreamInput in) throws IOException {
115-
super(in);
122+
super(in, false);
123+
// This used to be an AcknowledgedRequest so we need to read the ack timeout for BwC.
124+
in.readTimeValue();
116125
policyNames = in.readStringArray();
117126
}
118127

@@ -125,12 +134,6 @@ public String[] getPolicyNames() {
125134
return policyNames;
126135
}
127136

128-
@Override
129-
public void writeTo(StreamOutput out) throws IOException {
130-
super.writeTo(out);
131-
out.writeStringArray(policyNames);
132-
}
133-
134137
@Override
135138
public int hashCode() {
136139
return Arrays.hashCode(policyNames);
@@ -163,13 +166,11 @@ public LifecyclePolicyResponseItem(LifecyclePolicy lifecyclePolicy, long version
163166
this.usage = usage;
164167
}
165168

166-
LifecyclePolicyResponseItem(StreamInput in) throws IOException {
167-
this.lifecyclePolicy = new LifecyclePolicy(in);
168-
this.version = in.readVLong();
169-
this.modifiedDate = in.readString();
170-
this.usage = new ItemUsage(in);
171-
}
172-
169+
/**
170+
* NB prior to 9.1 this was a TransportMasterNodeAction so for BwC we must remain able to write these responses until
171+
* we no longer need to support calling this action remotely.
172+
*/
173+
@UpdateForV10(owner = UpdateForV10.Owner.DATA_MANAGEMENT)
173174
@Override
174175
public void writeTo(StreamOutput out) throws IOException {
175176
lifecyclePolicy.writeTo(out);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
9+
10+
public interface BoundedInferenceModel extends InferenceModel {
11+
double getMinPredictedValue();
12+
13+
double getMaxPredictedValue();
14+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
9+
10+
import org.elasticsearch.common.logging.LoggerMessageFormat;
11+
import org.elasticsearch.inference.InferenceResults;
12+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
13+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
14+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
15+
16+
import java.util.Map;
17+
18+
public class BoundedWindowInferenceModel implements BoundedInferenceModel {
19+
public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;
20+
21+
private final BoundedInferenceModel model;
22+
private final double minPredictedValue;
23+
private final double maxPredictedValue;
24+
private final double adjustmentValue;
25+
26+
public BoundedWindowInferenceModel(BoundedInferenceModel model) {
27+
this.model = model;
28+
this.minPredictedValue = model.getMinPredictedValue();
29+
this.maxPredictedValue = model.getMaxPredictedValue();
30+
31+
if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
32+
this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
33+
} else {
34+
this.adjustmentValue = 0.0;
35+
}
36+
}
37+
38+
@Override
39+
public String[] getFeatureNames() {
40+
return model.getFeatureNames();
41+
}
42+
43+
@Override
44+
public TargetType targetType() {
45+
return model.targetType();
46+
}
47+
48+
@Override
49+
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
50+
return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
51+
}
52+
53+
@Override
54+
public InferenceResults infer(double[] features, InferenceConfig config) {
55+
return boundInferenceResultScores(model.infer(features, config));
56+
}
57+
58+
@Override
59+
public boolean supportsFeatureImportance() {
60+
return model.supportsFeatureImportance();
61+
}
62+
63+
@Override
64+
public String getName() {
65+
return "bounded_window[" + model.getName() + "]";
66+
}
67+
68+
@Override
69+
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
70+
model.rewriteFeatureIndices(newFeatureIndexMapping);
71+
}
72+
73+
@Override
74+
public long ramBytesUsed() {
75+
return model.ramBytesUsed();
76+
}
77+
78+
@Override
79+
public double getMinPredictedValue() {
80+
return minPredictedValue;
81+
}
82+
83+
@Override
84+
public double getMaxPredictedValue() {
85+
return maxPredictedValue;
86+
}
87+
88+
private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
89+
// if the min value < the default minimum, slide the values up by the adjustment value
90+
if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
91+
double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
92+
93+
predictedValue += this.adjustmentValue;
94+
95+
return new RegressionInferenceResults(
96+
predictedValue,
97+
inferenceResult.getResultsField(),
98+
((RegressionInferenceResults) inferenceResult).getFeatureImportance()
99+
);
100+
}
101+
102+
throw new IllegalStateException(
103+
LoggerMessageFormat.format(
104+
"Model used within a {} should return a {} but got {} instead",
105+
BoundedWindowInferenceModel.class.getSimpleName(),
106+
RegressionInferenceResults.class.getSimpleName(),
107+
inferenceResult.getClass().getSimpleName()
108+
)
109+
);
110+
}
111+
112+
@Override
113+
public String toString() {
114+
return "BoundedWindowInferenceModel{"
115+
+ "model="
116+
+ model
117+
+ ", minPredictedValue="
118+
+ getMinPredictedValue()
119+
+ ", maxPredictedValue="
120+
+ getMaxPredictedValue()
121+
+ '}';
122+
}
123+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.RamUsageEstimator;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.util.CachedSupplier;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.core.Tuple;
1617
import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@
3637
import java.util.List;
3738
import java.util.Map;
3839
import java.util.Set;
40+
import java.util.function.Supplier;
3941
import java.util.stream.Collectors;
4042
import java.util.stream.IntStream;
4143

@@ -52,7 +54,7 @@
5254
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
5355
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
5456

55-
public class EnsembleInferenceModel implements InferenceModel {
57+
public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
5658

5759
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
5860
private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) {
9799
private final List<String> classificationLabels;
98100
private final double[] classificationWeights;
99101
private volatile boolean preparedForInference = false;
102+
private final Supplier<double[]> predictedValuesBoundariesSupplier;
100103

101104
private EnsembleInferenceModel(
102105
List<InferenceModel> models,
@@ -112,6 +115,7 @@ private EnsembleInferenceModel(
112115
this.classificationWeights = classificationWeights == null
113116
? null
114117
: classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
118+
this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
115119
}
116120

117121
@Override
@@ -328,21 +332,57 @@ public double[] getClassificationWeights() {
328332

329333
@Override
330334
public String toString() {
331-
return "EnsembleInferenceModel{"
332-
+ "featureNames="
333-
+ Arrays.toString(featureNames)
334-
+ ", models="
335-
+ models
336-
+ ", outputAggregator="
337-
+ outputAggregator
338-
+ ", targetType="
339-
+ targetType
340-
+ ", classificationLabels="
341-
+ classificationLabels
342-
+ ", classificationWeights="
343-
+ Arrays.toString(classificationWeights)
344-
+ ", preparedForInference="
345-
+ preparedForInference
346-
+ '}';
335+
StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
336+
337+
builder.append("featureNames=")
338+
.append(Arrays.toString(featureNames))
339+
.append(", models=")
340+
.append(models)
341+
.append(", outputAggregator=")
342+
.append(outputAggregator)
343+
.append(", targetType=")
344+
.append(targetType);
345+
346+
if (targetType == TargetType.CLASSIFICATION) {
347+
builder.append(", classificationLabels=")
348+
.append(classificationLabels)
349+
.append(", classificationWeights=")
350+
.append(Arrays.toString(classificationWeights));
351+
} else if (targetType == TargetType.REGRESSION) {
352+
builder.append(", minPredictedValue=")
353+
.append(getMinPredictedValue())
354+
.append(", maxPredictedValue=")
355+
.append(getMaxPredictedValue());
356+
}
357+
358+
builder.append(", preparedForInference=").append(preparedForInference);
359+
360+
return builder.append('}').toString();
361+
}
362+
363+
@Override
364+
public double getMinPredictedValue() {
365+
return this.predictedValuesBoundariesSupplier.get()[0];
366+
}
367+
368+
@Override
369+
public double getMaxPredictedValue() {
370+
return this.predictedValuesBoundariesSupplier.get()[1];
371+
}
372+
373+
private double[] initModelBoundaries() {
374+
double[] modelsMinBoundaries = new double[models.size()];
375+
double[] modelsMaxBoundaries = new double[models.size()];
376+
int i = 0;
377+
for (InferenceModel model : models) {
378+
if (model instanceof BoundedInferenceModel boundedInferenceModel) {
379+
modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
380+
modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
381+
} else {
382+
throw new IllegalStateException("All submodels have to be bounded");
383+
}
384+
}
385+
386+
return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
347387
}
348388
}

0 commit comments

Comments
 (0)