Skip to content

Commit 4ef37f5

Browse files
Refactoring the request
1 parent 4fe3a1f commit 4ef37f5

File tree

6 files changed

+210
-130
lines changed

6 files changed

+210
-130
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
5959
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
61+
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
6162
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
6263
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
6364
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
@@ -153,7 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
153154
addAlibabaCloudSearchNamedWriteables(namedWriteables);
154155
addJinaAINamedWriteables(namedWriteables);
155156
addVoyageAINamedWriteables(namedWriteables);
156-
addCustomWriteables(namedWriteables);
157+
addCustomNamedWriteables(namedWriteables);
157158

158159
addUnifiedNamedWriteables(namedWriteables);
159160

@@ -163,6 +164,32 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
163164
return namedWriteables;
164165
}
165166

167+
private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
168+
namedWriteables.add(
169+
new NamedWriteableRegistry.Entry(
170+
ServiceSettings.class,
171+
CustomServiceSettings.NAME,
172+
CustomServiceSettings::new
173+
)
174+
);
175+
176+
namedWriteables.add(
177+
new NamedWriteableRegistry.Entry(
178+
TaskSettings.class,
179+
CustomTaskSettings.NAME,
180+
CustomTaskSettings::new
181+
)
182+
);
183+
184+
namedWriteables.add(
185+
new NamedWriteableRegistry.Entry(
186+
SecretSettings.class,
187+
CustomSecretSettings.NAME,
188+
CustomSecretSettings::new
189+
)
190+
);
191+
}
192+
166193
private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
167194
var writeables = UnifiedCompletionRequest.getNamedWriteables();
168195
namedWriteables.addAll(writeables);
@@ -665,11 +692,4 @@ private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> nam
665692
)
666693
);
667694
}
668-
669-
private static void addCustomWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
670-
namedWriteables.add(
671-
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
672-
);
673-
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));
674-
}
675695
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.Locale;
3232
import java.util.Map;
3333
import java.util.Objects;
34+
import java.util.function.Function;
3435
import java.util.stream.Collectors;
3536

3637
import static org.elasticsearch.core.Strings.format;
@@ -479,29 +480,86 @@ public static Map<String, Object> extractOptionalMap(
479480
}
480481

481482
/**
482-
* Ensures that each value in the map is a {@link String}.
483-
* @param map a Map to iterate over
484-
* @param settingName the setting name that his map corresponds to
485-
* @param validationException aggregated validation exceptions
483+
* Ensures the values of the map match one of the supplied types.
484+
* @param map Map to validate
485+
* @param allowedTypes List of {@link Class} to accept
486+
* @param settingName the setting name for the field
487+
* @param validationException exception to return if one of the values is invalid
488+
* @param censorValue if true the key and value will be included in the exception message
486489
*/
487-
public static void validateMapValueStrings(
490+
public static void validateMapValues(
488491
Map<String, Object> map,
492+
List<Class<?>> allowedTypes,
489493
String settingName,
490-
ValidationException validationException
494+
ValidationException validationException,
495+
boolean censorValue
491496
) {
492497
if (map == null) {
493498
return;
494499
}
495500

496501
for (var entry : map.entrySet()) {
497-
var value = entry.getValue();
498-
if (value instanceof String == false) {
499-
validationException.addValidationError(ServiceUtils.invalidTypeErrorMsg(settingName, value, String.class.getSimpleName()));
502+
boolean isAllowed = false;
503+
504+
for (Class<?> allowedType : allowedTypes) {
505+
if (allowedType.isInstance(entry.getValue())) {
506+
isAllowed = true;
507+
break;
508+
}
509+
}
510+
511+
Function<String[], String> errorMessage = (String[] validTypesAsStrings) -> {
512+
if (censorValue) {
513+
return Strings.format(
514+
"Map field [%s] has an entry that is not valid. Value type is not one of [%s].",
515+
settingName,
516+
String.join(", ", validTypesAsStrings)
517+
);
518+
} else {
519+
return Strings.format(
520+
"Map field [%s] has an entry that is not valid, [%s => %s]. Value type is not one of [%s].",
521+
settingName,
522+
entry.getKey(),
523+
entry.getValue(),
524+
String.join(", ", validTypesAsStrings)
525+
);
526+
}
527+
};
528+
529+
if (isAllowed == false) {
530+
var validTypesAsStrings = allowedTypes.stream().map(Class::toString).toArray(String[]::new);
531+
Arrays.sort(validTypesAsStrings);
532+
533+
validationException.addValidationError(errorMessage.apply(validTypesAsStrings));
500534
throw validationException;
501535
}
502536
}
503537
}
504538

539+
public static void convertMapStringsToSecureString(Map<String, Object> map) {
540+
if (map == null) {
541+
return;
542+
}
543+
544+
for (var entry : map.entrySet()) {
545+
var value = entry.getValue();
546+
if (value instanceof String) {
547+
map.put(entry.getKey(), new SecureString(((String) value).toCharArray()));
548+
}
549+
}
550+
}
551+
552+
/**
553+
* Removes null values.
554+
*/
555+
public static void removeNullValues(Map<String, Object> map) {
556+
if (map == null) {
557+
return;
558+
}
559+
560+
map.values().removeIf(Objects::isNull);
561+
}
562+
505563
public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
506564
Map<String, Object> map,
507565
String settingName,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomSecretSettings.java

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818

1919
import java.io.IOException;
2020
import java.util.HashMap;
21+
import java.util.List;
2122
import java.util.Map;
2223
import java.util.Objects;
2324

25+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertMapStringsToSecureString;
2426
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
27+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
28+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;
2529

2630
public class CustomSecretSettings implements SecretSettings {
2731
public static final String NAME = "custom_secret_settings";
@@ -36,20 +40,15 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
3640
ValidationException validationException = new ValidationException();
3741

3842
Map<String, Object> requestSecretParamsMap = extractOptionalMap(map, SECRET_PARAMETERS, NAME, validationException);
43+
removeNullValues(requestSecretParamsMap);
44+
validateMapValues(requestSecretParamsMap, List.of(String.class), SECRET_PARAMETERS, validationException, true);
45+
convertMapStringsToSecureString(requestSecretParamsMap);
46+
3947
if (validationException.validationErrors().isEmpty() == false) {
4048
throw validationException;
4149
}
4250

43-
if (requestSecretParamsMap == null) {
44-
return null;
45-
} else {
46-
Map<String, Object> secureSecretParameters = new HashMap<>();
47-
for (String paramKey : requestSecretParamsMap.keySet()) {
48-
Object paramValue = requestSecretParamsMap.get(paramKey);
49-
secureSecretParameters.put(paramKey, paramValue);
50-
}
51-
return new CustomSecretSettings(secureSecretParameters);
52-
}
51+
return new CustomSecretSettings(Objects.requireNonNullElse(requestSecretParamsMap, new HashMap<>()));
5352
}
5453

5554
@Override
@@ -58,15 +57,11 @@ public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
5857
}
5958

6059
public CustomSecretSettings(@Nullable Map<String, Object> secretParameters) {
61-
this.secretParameters = secretParameters;
60+
this.secretParameters = Objects.requireNonNullElse(secretParameters, new HashMap<>());
6261
}
6362

6463
public CustomSecretSettings(StreamInput in) throws IOException {
65-
if (in.readBoolean()) {
66-
secretParameters = in.readGenericMap();
67-
} else {
68-
secretParameters = null;
69-
}
64+
secretParameters = in.readGenericMap();
7065
}
7166

7267
public Map<String, Object> getSecretParameters() {
@@ -76,7 +71,7 @@ public Map<String, Object> getSecretParameters() {
7671
@Override
7772
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
7873
builder.startObject();
79-
if (secretParameters != null) {
74+
if (secretParameters.isEmpty() == false) {
8075
builder.field(SECRET_PARAMETERS, secretParameters);
8176
}
8277
builder.endObject();
@@ -95,12 +90,7 @@ public TransportVersion getMinimalSupportedVersion() {
9590

9691
@Override
9792
public void writeTo(StreamOutput out) throws IOException {
98-
if (secretParameters != null) {
99-
out.writeBoolean(true);
100-
out.writeGenericMap(secretParameters);
101-
} else {
102-
out.writeBoolean(false);
103-
}
93+
out.writeGenericMap(secretParameters);
10494
}
10595

10696
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import java.io.IOException;
3434
import java.util.HashMap;
35+
import java.util.List;
3536
import java.util.Map;
3637
import java.util.Objects;
3738

@@ -43,8 +44,9 @@
4344
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
4445
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
4546
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
47+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues;
4648
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
47-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValueStrings;
49+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapValues;
4850

4951
public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {
5052
public static final String NAME = "custom_service_settings";
@@ -84,7 +86,7 @@ private record Fields(
8486
RateLimitSettings rateLimitSettings
8587
) {
8688
public void validate(ValidationException validationException) {
87-
validateMapValueStrings(headers, HEADERS, validationException);
89+
validateMapValues(headers, List.of(String.class), HEADERS, validationException, false);
8890

8991
if (requestBodyMap == null || responseParserMap == null || jsonParserMap == null) {
9092
throw validationException;
@@ -113,6 +115,7 @@ private static Fields from(
113115
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
114116

115117
Map<String, Object> headers = extractOptionalMap(map, HEADERS, ModelConfigurations.SERVICE_SETTINGS, validationException);
118+
removeNullValues(headers);
116119

117120
Map<String, Object> requestBodyMap = extractRequiredMap(map, REQUEST, ModelConfigurations.SERVICE_SETTINGS, validationException);
118121

@@ -165,10 +168,10 @@ private static Fields from(
165168
private static CustomServiceSettings fromRequestMap(Map<String, Object> map, TaskType taskType) {
166169
ValidationException validationException = new ValidationException();
167170

168-
var serviceSettings = from(map, ConfigurationParseContext.REQUEST, taskType, validationException);
171+
var serviceSettingsFields = from(map, ConfigurationParseContext.REQUEST, taskType, validationException);
169172

170-
serviceSettings.validate(validationException);
171-
return CustomServiceSettings.of(serviceSettings);
173+
serviceSettingsFields.validate(validationException);
174+
return CustomServiceSettings.of(serviceSettingsFields);
172175
}
173176

174177
private static CustomServiceSettings of(Fields fields) {

0 commit comments

Comments
 (0)