Skip to content

Commit 8a61676

Browse files
jan-elastic and albertzaharovits
authored and committed
Fix put inference endpoint with adaptive allocations (#110640)
* Fix put inference endpoint with adaptive allocations
* Better validation of adaptive allocations settings in inference endpoints
* Safeguard for max allocations
* remove debug code
1 parent c1f97f4 commit 8a61676

File tree

12 files changed

+160
-66
lines changed

12 files changed

+160
-66
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRequestValidationException;
1213
import org.elasticsearch.common.ValidationException;
1314
import org.elasticsearch.common.settings.SecureString;
1415
import org.elasticsearch.core.Nullable;
@@ -131,18 +132,31 @@ public static Object removeAsOneOfTypes(
131132
return null;
132133
}
133134

134-
public static AdaptiveAllocationsSettings removeAsAdaptiveAllocationsSettings(Map<String, Object> sourceMap, String key) {
135+
public static AdaptiveAllocationsSettings removeAsAdaptiveAllocationsSettings(
136+
Map<String, Object> sourceMap,
137+
String key,
138+
ValidationException validationException
139+
) {
135140
if (AdaptiveAllocationsFeatureFlag.isEnabled() == false) {
136141
return null;
137142
}
138143
Map<String, Object> settingsMap = ServiceUtils.removeFromMap(sourceMap, key);
139-
return settingsMap == null
140-
? null
141-
: new AdaptiveAllocationsSettings(
142-
ServiceUtils.removeAsType(settingsMap, ENABLED.getPreferredName(), Boolean.class),
143-
ServiceUtils.removeAsType(settingsMap, MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class),
144-
ServiceUtils.removeAsType(settingsMap, MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class)
145-
);
144+
if (settingsMap == null) {
145+
return null;
146+
}
147+
AdaptiveAllocationsSettings settings = new AdaptiveAllocationsSettings(
148+
ServiceUtils.removeAsType(settingsMap, ENABLED.getPreferredName(), Boolean.class, validationException),
149+
ServiceUtils.removeAsType(settingsMap, MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class, validationException),
150+
ServiceUtils.removeAsType(settingsMap, MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class, validationException)
151+
);
152+
for (String settingName : settingsMap.keySet()) {
153+
validationException.addValidationError(invalidSettingError(settingName, key));
154+
}
155+
ActionRequestValidationException exception = settings.validate();
156+
if (exception != null) {
157+
validationException.addValidationErrors(exception.validationErrors());
158+
}
159+
return settings;
146160
}
147161

148162
@SuppressWarnings("unchecked")
@@ -196,6 +210,10 @@ public static String missingSettingErrorMsg(String settingName, String scope) {
196210
return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
197211
}
198212

213+
public static String missingOneOfSettingsErrorMsg(List<String> settingNames, String scope) {
214+
return Strings.format("[%s] does not contain one of the required settings [%s]", scope, String.join(", ", settingNames));
215+
}
216+
199217
public static String invalidTypeErrorMsg(String settingName, Object foundObject, String expectedType) {
200218
return Strings.format(
201219
"field [%s] is not of the expected type. The value [%s] cannot be converted to a [%s]",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12-
import org.elasticsearch.action.ActionRequestValidationException;
1312
import org.elasticsearch.common.ValidationException;
1413
import org.elasticsearch.common.io.stream.StreamInput;
1514
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -21,6 +20,7 @@
2120
import java.io.IOException;
2221
import java.util.Map;
2322

23+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
2424
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger;
2525
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
2626

@@ -29,7 +29,7 @@ public class CustomElandInternalServiceSettings extends ElasticsearchInternalSer
2929
public static final String NAME = "custom_eland_model_internal_service_settings";
3030

3131
public CustomElandInternalServiceSettings(
32-
int numAllocations,
32+
Integer numAllocations,
3333
int numThreads,
3434
String modelId,
3535
AdaptiveAllocationsSettings adaptiveAllocationsSettings
@@ -51,7 +51,7 @@ public CustomElandInternalServiceSettings(
5151
public static CustomElandInternalServiceSettings fromMap(Map<String, Object> map) {
5252
ValidationException validationException = new ValidationException();
5353

54-
Integer numAllocations = extractRequiredPositiveInteger(
54+
Integer numAllocations = extractOptionalPositiveInteger(
5555
map,
5656
NUM_ALLOCATIONS,
5757
ModelConfigurations.SERVICE_SETTINGS,
@@ -60,14 +60,9 @@ public static CustomElandInternalServiceSettings fromMap(Map<String, Object> map
6060
Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException);
6161
AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings(
6262
map,
63-
ADAPTIVE_ALLOCATIONS
63+
ADAPTIVE_ALLOCATIONS,
64+
validationException
6465
);
65-
if (adaptiveAllocationsSettings != null) {
66-
ActionRequestValidationException exception = adaptiveAllocationsSettings.validate();
67-
if (exception != null) {
68-
validationException.addValidationErrors(exception.validationErrors());
69-
}
70-
}
7166
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
7267

7368
if (validationException.validationErrors().isEmpty() == false) {
@@ -99,7 +94,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
9994

10095
public CustomElandInternalServiceSettings(StreamInput in) throws IOException {
10196
super(
102-
in.readVInt(),
97+
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) ? in.readOptionalVInt() : in.readVInt(),
10398
in.readVInt(),
10499
in.readString(),
105100
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12-
import org.elasticsearch.action.ActionRequestValidationException;
1312
import org.elasticsearch.common.ValidationException;
1413
import org.elasticsearch.common.io.stream.StreamInput;
1514
import org.elasticsearch.inference.ModelConfigurations;
@@ -21,6 +20,7 @@
2120
import java.util.Map;
2221
import java.util.Objects;
2322

23+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
2424
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger;
2525
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
2626

@@ -30,7 +30,7 @@ public class ElasticsearchInternalServiceSettings extends InternalServiceSetting
3030
private static final int FAILED_INT_PARSE_VALUE = -1;
3131

3232
public static ElasticsearchInternalServiceSettings fromMap(Map<String, Object> map, ValidationException validationException) {
33-
Integer numAllocations = extractRequiredPositiveInteger(
33+
Integer numAllocations = extractOptionalPositiveInteger(
3434
map,
3535
NUM_ALLOCATIONS,
3636
ModelConfigurations.SERVICE_SETTINGS,
@@ -39,28 +39,23 @@ public static ElasticsearchInternalServiceSettings fromMap(Map<String, Object> m
3939
Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException);
4040
AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings(
4141
map,
42-
ADAPTIVE_ALLOCATIONS
42+
ADAPTIVE_ALLOCATIONS,
43+
validationException
4344
);
44-
if (adaptiveAllocationsSettings != null) {
45-
ActionRequestValidationException exception = adaptiveAllocationsSettings.validate();
46-
if (exception != null) {
47-
validationException.addValidationErrors(exception.validationErrors());
48-
}
49-
}
5045
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
5146

5247
// if an error occurred while parsing, we'll set these to an invalid value, so we don't accidentally get a
5348
// null pointer when doing unboxing
5449
return new ElasticsearchInternalServiceSettings(
55-
Objects.requireNonNullElse(numAllocations, FAILED_INT_PARSE_VALUE),
50+
numAllocations,
5651
Objects.requireNonNullElse(numThreads, FAILED_INT_PARSE_VALUE),
5752
modelId,
5853
adaptiveAllocationsSettings
5954
);
6055
}
6156

6257
public ElasticsearchInternalServiceSettings(
63-
int numAllocations,
58+
Integer numAllocations,
6459
int numThreads,
6560
String modelVariant,
6661
AdaptiveAllocationsSettings adaptiveAllocationsSettings
@@ -70,7 +65,7 @@ public ElasticsearchInternalServiceSettings(
7065

7166
public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
7267
super(
73-
in.readVInt(),
68+
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) ? in.readOptionalVInt() : in.readVInt(),
7469
in.readVInt(),
7570
in.readString(),
7671
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

1010
import org.elasticsearch.TransportVersions;
11-
import org.elasticsearch.action.ActionRequestValidationException;
1211
import org.elasticsearch.common.ValidationException;
1312
import org.elasticsearch.common.io.stream.StreamInput;
1413
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -22,8 +21,10 @@
2221

2322
import java.io.IOException;
2423
import java.util.Arrays;
24+
import java.util.List;
2525
import java.util.Map;
2626

27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
2728
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger;
2829

2930
public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInternalServiceSettings {
@@ -34,7 +35,7 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
3435
static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE;
3536

3637
public MultilingualE5SmallInternalServiceSettings(
37-
int numAllocations,
38+
Integer numAllocations,
3839
int numThreads,
3940
String modelId,
4041
AdaptiveAllocationsSettings adaptiveAllocationsSettings
@@ -44,7 +45,7 @@ public MultilingualE5SmallInternalServiceSettings(
4445

4546
public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOException {
4647
super(
47-
in.readVInt(),
48+
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) ? in.readOptionalVInt() : in.readVInt(),
4849
in.readVInt(),
4950
in.readString(),
5051
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)
@@ -74,7 +75,7 @@ public static MultilingualE5SmallInternalServiceSettings.Builder fromMap(Map<Str
7475
}
7576

7677
private static RequestFields extractRequestFields(Map<String, Object> map, ValidationException validationException) {
77-
Integer numAllocations = extractRequiredPositiveInteger(
78+
Integer numAllocations = extractOptionalPositiveInteger(
7879
map,
7980
NUM_ALLOCATIONS,
8081
ModelConfigurations.SERVICE_SETTINGS,
@@ -83,13 +84,16 @@ private static RequestFields extractRequestFields(Map<String, Object> map, Valid
8384
Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException);
8485
AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings(
8586
map,
86-
ADAPTIVE_ALLOCATIONS
87+
ADAPTIVE_ALLOCATIONS,
88+
validationException
8789
);
88-
if (adaptiveAllocationsSettings != null) {
89-
ActionRequestValidationException exception = adaptiveAllocationsSettings.validate();
90-
if (exception != null) {
91-
validationException.addValidationErrors(exception.validationErrors());
92-
}
90+
if (numAllocations == null && adaptiveAllocationsSettings == null) {
91+
validationException.addValidationError(
92+
ServiceUtils.missingOneOfSettingsErrorMsg(
93+
List.of(NUM_ALLOCATIONS, ADAPTIVE_ALLOCATIONS),
94+
ModelConfigurations.SERVICE_SETTINGS
95+
)
96+
);
9397
}
9498
String modelId = ServiceUtils.removeAsType(map, MODEL_ID, String.class);
9599
if (modelId != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12-
import org.elasticsearch.action.ActionRequestValidationException;
1312
import org.elasticsearch.common.ValidationException;
1413
import org.elasticsearch.common.io.stream.StreamInput;
1514
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -19,9 +18,11 @@
1918
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
2019

2120
import java.io.IOException;
21+
import java.util.List;
2222
import java.util.Map;
2323
import java.util.Objects;
2424

25+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
2526
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
2627
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger;
2728

@@ -41,7 +42,7 @@ public class ElserInternalServiceSettings extends InternalServiceSettings {
4142
public static ElserInternalServiceSettings.Builder fromMap(Map<String, Object> map) {
4243
ValidationException validationException = new ValidationException();
4344

44-
Integer numAllocations = extractRequiredPositiveInteger(
45+
Integer numAllocations = extractOptionalPositiveInteger(
4546
map,
4647
NUM_ALLOCATIONS,
4748
ModelConfigurations.SERVICE_SETTINGS,
@@ -50,13 +51,16 @@ public static ElserInternalServiceSettings.Builder fromMap(Map<String, Object> m
5051
Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException);
5152
AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings(
5253
map,
53-
ADAPTIVE_ALLOCATIONS
54+
ADAPTIVE_ALLOCATIONS,
55+
validationException
5456
);
55-
if (adaptiveAllocationsSettings != null) {
56-
ActionRequestValidationException exception = adaptiveAllocationsSettings.validate();
57-
if (exception != null) {
58-
validationException.addValidationErrors(exception.validationErrors());
59-
}
57+
if (numAllocations == null && adaptiveAllocationsSettings == null) {
58+
validationException.addValidationError(
59+
ServiceUtils.missingOneOfSettingsErrorMsg(
60+
List.of(NUM_ALLOCATIONS, ADAPTIVE_ALLOCATIONS),
61+
ModelConfigurations.SERVICE_SETTINGS
62+
)
63+
);
6064
}
6165
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
6266

@@ -87,7 +91,7 @@ public ElserInternalServiceSettings build() {
8791
}
8892

8993
public ElserInternalServiceSettings(
90-
int numAllocations,
94+
Integer numAllocations,
9195
int numThreads,
9296
String modelId,
9397
AdaptiveAllocationsSettings adaptiveAllocationsSettings
@@ -98,7 +102,7 @@ public ElserInternalServiceSettings(
98102

99103
public ElserInternalServiceSettings(StreamInput in) throws IOException {
100104
super(
101-
in.readVInt(),
105+
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) ? in.readOptionalVInt() : in.readVInt(),
102106
in.readVInt(),
103107
in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X) ? in.readString() : ElserInternalService.ELSER_V2_MODEL,
104108
in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)
@@ -119,7 +123,11 @@ public TransportVersion getMinimalSupportedVersion() {
119123

120124
@Override
121125
public void writeTo(StreamOutput out) throws IOException {
122-
out.writeVInt(getNumAllocations());
126+
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
127+
out.writeOptionalVInt(getNumAllocations());
128+
} else {
129+
out.writeVInt(getNumAllocations());
130+
}
123131
out.writeVInt(getNumThreads());
124132
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X)) {
125133
out.writeString(getModelId());

0 commit comments

Comments
 (0)