diff --git a/docs/changelog/133599.yaml b/docs/changelog/133599.yaml new file mode 100644 index 0000000000000..942b739c920f1 --- /dev/null +++ b/docs/changelog/133599.yaml @@ -0,0 +1,5 @@ +pr: 133599 +summary: Support Gemini thinking budget in inference API +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2d6c47eefffaf..3d8d4762e056a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -359,6 +359,7 @@ static TransportVersion def(int id) { public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00); public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00); public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00); + public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f8fb375022abb..35b3977b7049c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -89,6 +89,7 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings; @@ -570,6 +571,14 @@ private static void addGoogleVertexAiNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 3926db355bb94..7420f716efeae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -799,6 +799,25 @@ public static Integer extractOptionalPositiveInteger( String settingName, String scope, ValidationException validationException + ) { + return extractOptionalInteger(map, settingName, scope, validationException, true); + } + + public static Integer extractOptionalInteger( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + return extractOptionalInteger(map, settingName, scope, validationException, false); + } + + private static Integer extractOptionalInteger( + Map map, + String settingName, + String scope, + ValidationException validationException, + boolean mustBePositive ) { int initialValidationErrorCount = validationException.validationErrors().size(); Integer optionalField = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException); @@ -807,7 +826,7 @@ public static Integer extractOptionalPositiveInteger( return null; } - if (optionalField != null && optionalField <= 0) { + if (optionalField != null && mustBePositive && optionalField <= 0) { validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField)); return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index 80d82df1cac26..b0034f587f363 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -73,13 +73,14 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings) { + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); var manager = new GenericRequestManager<>( serviceComponents.threadPool(), - model, + overriddenModel, CHAT_COMPLETION_HANDLER, - inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel), ChatCompletionInput.class ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index fdb4ed34d92db..8e9174184223b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -9,7 +9,6 @@ import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -47,7 +46,7 @@ public GoogleVertexAiChatCompletionModel( taskType, service, GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context), - new EmptyTaskSettings(), + GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings), GoogleVertexAiSecretSettings.fromMap(secrets) ); } @@ -57,7 +56,7 @@ public GoogleVertexAiChatCompletionModel( TaskType taskType, String service, GoogleVertexAiChatCompletionServiceSettings serviceSettings, - EmptyTaskSettings taskSettings, + GoogleVertexAiChatCompletionTaskSettings taskSettings, @Nullable GoogleVertexAiSecretSettings secrets ) { super( @@ -73,6 +72,14 @@ public GoogleVertexAiChatCompletionModel( } } + private GoogleVertexAiChatCompletionModel( + GoogleVertexAiChatCompletionModel model, + GoogleVertexAiChatCompletionTaskSettings taskSettings + ) { + super(model, taskSettings); + streamingURI = model.streamingURI(); + } + public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) { var originalModelServiceSettings = model.getServiceSettings(); @@ -93,6 +100,26 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM ); } + /** + * Overrides the task settings in the given model with the settings in the map. If no new settings are present or the provided settings + * do not differ from those already in the model, returns the original model + * @param model the model whose task settings will be overridden + * @param taskSettingsMap the new task settings to use + * @return a {@link GoogleVertexAiChatCompletionModel} with overridden {@link GoogleVertexAiChatCompletionTaskSettings} + */ + public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, Map taskSettingsMap) { + if (taskSettingsMap == null || taskSettingsMap.isEmpty()) { + return model; + } + + var requestTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap); + if (requestTaskSettings.isEmpty() || model.getTaskSettings().equals(requestTaskSettings)) { + return model; + } + var combinedTaskSettings = GoogleVertexAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings); + return new GoogleVertexAiChatCompletionModel(model, combinedTaskSettings); + } + @Override public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); @@ -109,8 +136,8 @@ public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() { } @Override - public EmptyTaskSettings getTaskSettings() { - return (EmptyTaskSettings) super.getTaskSettings(); + public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() { + return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java index a753fc5dc66f2..cea44e3f3b780 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -116,6 +117,11 @@ public String getWriteableName() { return NAME; } + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + @Override public TransportVersion getMinimalSupportedVersion() { assert false : "should never be called when supportsVersion is used"; @@ -161,7 +167,7 @@ public int hashCode() { } @Override - public RateLimitSettings rateLimitSettings() { - return rateLimitSettings; + public String toString() { + return Strings.toString(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java new file mode 100644 index 0000000000000..7b18ee1fc9fad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettings.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings { + public static final String NAME = "google_vertex_ai_chatcompletion_task_settings"; + + private final ThinkingConfig thinkingConfig; + + public static final GoogleVertexAiChatCompletionTaskSettings EMPTY_SETTINGS = new GoogleVertexAiChatCompletionTaskSettings(); + private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig(); + + public GoogleVertexAiChatCompletionTaskSettings() { + thinkingConfig = EMPTY_THINKING_CONFIG; + } + + public GoogleVertexAiChatCompletionTaskSettings(ThinkingConfig thinkingConfig) { + this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG); + } + + public GoogleVertexAiChatCompletionTaskSettings(StreamInput in) throws IOException { + thinkingConfig = new ThinkingConfig(in); + } + + public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map taskSettings) { + ValidationException validationException = new ValidationException(); + + // Extract optional thinkingConfig settings + ThinkingConfig thinkingConfig = ThinkingConfig.fromMap(taskSettings, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig); + } + + public static GoogleVertexAiChatCompletionTaskSettings of( + GoogleVertexAiChatCompletionTaskSettings originalTaskSettings, + GoogleVertexAiChatCompletionTaskSettings newTaskSettings + ) { + ThinkingConfig thinkingConfig = newTaskSettings.thinkingConfig().isEmpty() + ? originalTaskSettings.thinkingConfig() + : newTaskSettings.thinkingConfig(); + return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig); + } + + public ThinkingConfig thinkingConfig() { + return thinkingConfig; + } + + @Override + public boolean isEmpty() { + return thinkingConfig.isEmpty(); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + GoogleVertexAiChatCompletionTaskSettings newTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap( + new HashMap<>(newSettings) + ); + return GoogleVertexAiChatCompletionTaskSettings.of(this, newTaskSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.GEMINI_THINKING_BUDGET_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + thinkingConfig.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + thinkingConfig.toXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + GoogleVertexAiChatCompletionTaskSettings that = (GoogleVertexAiChatCompletionTaskSettings) o; + return Objects.equals(thinkingConfig, that.thinkingConfig); + } + + @Override + public int hashCode() { + return Objects.hashCode(thinkingConfig); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfig.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfig.java new file mode 100644 index 0000000000000..74edc0cb7f711 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfig.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; + +/** + * This class encapsulates the ThinkingConfig object contained within GenerationConfig. Only the thinkingBudget field is currently + * supported, but the includeThoughts field may be added in the future + * + * @see Gemini Thinking documentation + */ +public class ThinkingConfig implements Writeable, ToXContentFragment { + public static final String THINKING_CONFIG_FIELD = "thinking_config"; + public static final String THINKING_BUDGET_FIELD = "thinking_budget"; + + private final Integer thinkingBudget; + + /** + * Constructor for an empty {@code ThinkingConfig} + */ + public ThinkingConfig() { + this.thinkingBudget = null; + } + + public ThinkingConfig(Integer thinkingBudget) { + this.thinkingBudget = thinkingBudget; + } + + public ThinkingConfig(StreamInput in) throws IOException { + thinkingBudget = in.readOptionalVInt(); + } + + public static ThinkingConfig fromMap(Map map, ValidationException validationException) { + Map thinkingConfigSettings = removeFromMapOrDefaultEmpty(map, THINKING_CONFIG_FIELD); + Integer thinkingBudget = ServiceUtils.extractOptionalInteger( + thinkingConfigSettings, + THINKING_BUDGET_FIELD, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + return new ThinkingConfig(thinkingBudget); + } + + public boolean isEmpty() { + return thinkingBudget == null; + } + + public Integer getThinkingBudget() { + return thinkingBudget; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(thinkingBudget); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (thinkingBudget != null) { + builder.startObject(THINKING_CONFIG_FIELD); + builder.field(THINKING_BUDGET_FIELD, thinkingBudget); + builder.endObject(); + } + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ThinkingConfig that = (ThinkingConfig) o; + return Objects.equals(thinkingBudget, that.thinkingBudget); + } + + @Override + public int hashCode() { + return Objects.hashCode(thinkingBudget); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java index 7acc859d26748..c6c699c7e85d5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -37,7 +37,10 @@ public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatIn public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(uri); - var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + model.getTaskSettings().thinkingConfig() + ); ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java index 7b8f75b2853bb..7e625530f197a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java @@ -19,6 +19,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; import java.io.IOException; import java.util.Locale; @@ -37,6 +38,8 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String TEMPERATURE = "temperature"; private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens"; private static final String TOP_P = "topP"; + private static final String THINKING_CONFIG = "thinkingConfig"; + private static final String THINKING_BUDGET = "thinkingBudget"; private static final String TOOLS = "tools"; private static final String FUNCTION_DECLARATIONS = "functionDeclarations"; @@ -56,6 +59,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String FUNCTION_CALL_ARGS = "args"; private final UnifiedChatInput unifiedChatInput; + private final ThinkingConfig thinkingConfig; private static final String USER_ROLE = "user"; private static final String MODEL_ROLE = "model"; @@ -66,8 +70,9 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont private static final String SYSTEM_INSTRUCTION = "systemInstruction"; - public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { + public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, ThinkingConfig thinkingConfig) { this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.thinkingConfig = Objects.requireNonNull(thinkingConfig); } private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) { @@ -316,7 +321,8 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException { boolean hasAnyConfig = request.stop() != null || request.temperature() != null || request.maxCompletionTokens() != null - || request.topP() != null; + || request.topP() != null + || thinkingConfig.isEmpty() == false; if (hasAnyConfig == false) { return; @@ -336,6 +342,11 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException { if (request.topP() != null) { builder.field(TOP_P, request.topP()); } + if (thinkingConfig.isEmpty() == false) { + builder.startObject(THINKING_CONFIG); + builder.field(THINKING_BUDGET, thinkingConfig.getThinkingBudget()); + builder.endObject(); + } builder.endObject(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 7e60e380b9071..f8de6887e515e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -31,6 +31,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalList; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalListOfStringTuples; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap; @@ -540,7 +541,7 @@ public void testExtractOptionalList_AddsException_WhenFieldContainsMixedTypeValu assertTrue(map.isEmpty()); } - public void testExtractOptionalPositiveInt() { + public void testExtractOptionalPositiveInteger_returnsInteger_withPositiveInteger() { var validation = new ValidationException(); validation.addValidationError("previous error"); Map map = modifiableMap(Map.of("abc", 1)); @@ -548,6 +549,77 @@ public void testExtractOptionalPositiveInt() { assertThat(validation.validationErrors(), hasSize(1)); } + public void testExtractOptionalPositiveInteger_returnsNull_whenSettingNotFound() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("abc", 1)); + assertThat(extractOptionalPositiveInteger(map, "not_abc", "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(1)); + } + + public void testExtractOptionalPositiveInteger_returnsNull_addsValidationError_whenObjectIsNotInteger() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + String setting = "abc"; + Map map = modifiableMap(Map.of(setting, "not_an_int")); + assertThat(extractOptionalPositiveInteger(map, setting, "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(2)); + assertThat(validation.validationErrors().getLast(), containsString("cannot be converted to a [Integer]")); + } + + public void testExtractOptionalPositiveInteger_returnNull_addsValidationError_withNonPositiveInteger() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + String zeroKey = "zero"; + String negativeKey = "negative"; + Map map = modifiableMap(Map.of(zeroKey, 0, negativeKey, -1)); + + // Test zero + assertThat(extractOptionalPositiveInteger(map, zeroKey, "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(2)); + assertThat(validation.validationErrors().getLast(), containsString("[" + zeroKey + "] must be a positive integer")); + + // Test a negative number + assertThat(extractOptionalPositiveInteger(map, negativeKey, "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(3)); + assertThat(validation.validationErrors().getLast(), containsString("[" + negativeKey + "] must be a positive integer")); + } + + public void testExtractOptionalInteger_returnsInteger() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + String positiveKey = "positive"; + int positiveValue = 123; + String zeroKey = "zero"; + int zeroValue = 0; + String negativeKey = "negative"; + int negativeValue = -123; + Map map = modifiableMap(Map.of(positiveKey, positiveValue, zeroKey, zeroValue, negativeKey, negativeValue)); + + assertThat(extractOptionalInteger(map, positiveKey, "scope", validation), is(positiveValue)); + assertThat(extractOptionalInteger(map, zeroKey, "scope", validation), is(zeroValue)); + assertThat(extractOptionalInteger(map, negativeKey, "scope", validation), is(negativeValue)); + assertThat(validation.validationErrors(), hasSize(1)); + } + + public void testExtractOptionalInteger_returnsNull_whenSettingNotFound() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("abc", 1)); + assertThat(extractOptionalInteger(map, "not_abc", "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(1)); + } + + public void testExtractOptionalInteger_returnsNull_addsValidationError_whenObjectIsNotInteger() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + String setting = "abc"; + Map map = modifiableMap(Map.of(setting, "not_an_int")); + assertThat(extractOptionalInteger(map, setting, "scope", validation), is(nullValue())); + assertThat(validation.validationErrors(), hasSize(2)); + assertThat(validation.validationErrors().getLast(), containsString("cannot be converted to a [Integer]")); + } + public void testExtractOptionalPositiveLong_IntegerValue() { var validation = new ValidationException(); validation.addValidationError("previous error"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java index a634b04d03d9e..0e720f60dfe2d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.junit.After; @@ -124,7 +125,8 @@ private ExecutableAction createAction(String location, String projectId, String location, actualModelId, "api-key", - new RateLimitSettings(100) + new RateLimitSettings(100), + new ThinkingConfig(256) ); var manager = new GenericRequestManager<>( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java index 6a0ec6edfaa79..cb9fd803047bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.completion; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; @@ -17,8 +16,12 @@ import java.net.URI; import java.net.URISyntaxException; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_BUDGET_FIELD; +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_CONFIG_FIELD; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -31,9 +34,17 @@ public class GoogleVertexAiChatCompletionModelTests extends ESTestCase { private static final String DEFAULT_MODEL_ID = "gemini-pro"; private static final String DEFAULT_API_KEY = "test-api-key"; private static final RateLimitSettings DEFAULT_RATE_LIMIT = new RateLimitSettings(100); + private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig(); public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { - var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var model = createCompletionModel( + DEFAULT_PROJECT_ID, + DEFAULT_LOCATION, + DEFAULT_MODEL_ID, + DEFAULT_API_KEY, + DEFAULT_RATE_LIMIT, + EMPTY_THINKING_CONFIG + ); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), "gemini-flash", @@ -54,11 +65,18 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); - assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(EMPTY_THINKING_CONFIG)); } public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { - var model = createCompletionModel(DEFAULT_PROJECT_ID, DEFAULT_LOCATION, DEFAULT_MODEL_ID, DEFAULT_API_KEY, DEFAULT_RATE_LIMIT); + var model = createCompletionModel( + DEFAULT_PROJECT_ID, + DEFAULT_LOCATION, + DEFAULT_MODEL_ID, + DEFAULT_API_KEY, + DEFAULT_RATE_LIMIT, + EMPTY_THINKING_CONFIG + ); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)), null, @@ -78,7 +96,7 @@ public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenReques assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); - assertThat(overriddenModel.getTaskSettings(), is(model.getTaskSettings())); + assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(EMPTY_THINKING_CONFIG)); assertThat(overriddenModel, not(sameInstance(model))); } @@ -95,19 +113,66 @@ public void testBuildUri() throws URISyntaxException { assertThat(actualUri, is(expectedUri)); } + public void testOf_overridesTaskSettings_whenPresent() { + var model = createCompletionModel( + DEFAULT_PROJECT_ID, + DEFAULT_LOCATION, + DEFAULT_MODEL_ID, + DEFAULT_API_KEY, + DEFAULT_RATE_LIMIT, + new ThinkingConfig(123) + ); + int newThinkingBudget = 456; + Map taskSettings = new HashMap<>( + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, newThinkingBudget))) + ); + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); + + assertThat(overriddenModel.getServiceSettings().modelId(), is(DEFAULT_MODEL_ID)); + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + + assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(new ThinkingConfig(newThinkingBudget))); + } + + public void testOf_doesNotOverrideTaskSettings_whenNotPresent() { + ThinkingConfig originalThinkingConfig = new ThinkingConfig(123); + var model = createCompletionModel( + DEFAULT_PROJECT_ID, + DEFAULT_LOCATION, + DEFAULT_MODEL_ID, + DEFAULT_API_KEY, + DEFAULT_RATE_LIMIT, + originalThinkingConfig + ); + Map taskSettings = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>())); + var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings); + + assertThat(overriddenModel.getServiceSettings().modelId(), is(DEFAULT_MODEL_ID)); + assertThat(overriddenModel.getServiceSettings().projectId(), is(DEFAULT_PROJECT_ID)); + assertThat(overriddenModel.getServiceSettings().location(), is(DEFAULT_LOCATION)); + assertThat(overriddenModel.getServiceSettings().rateLimitSettings(), is(DEFAULT_RATE_LIMIT)); + assertThat(overriddenModel.getSecretSettings().serviceAccountJson(), equalTo(new SecureString(DEFAULT_API_KEY.toCharArray()))); + + assertThat(overriddenModel.getTaskSettings().thinkingConfig(), is(originalThinkingConfig)); + } + public static GoogleVertexAiChatCompletionModel createCompletionModel( String projectId, String location, String modelId, String apiKey, - RateLimitSettings rateLimitSettings + RateLimitSettings rateLimitSettings, + ThinkingConfig thinkingConfig ) { return new GoogleVertexAiChatCompletionModel( "google-vertex-ai-chat-test-id", TaskType.CHAT_COMPLETION, "google_vertex_ai", new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings), - new EmptyTaskSettings(), + new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig), new GoogleVertexAiSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java new file mode 100644 index 0000000000000..cc567b24fe773 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionTaskSettingsTests.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_BUDGET_FIELD; +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_CONFIG_FIELD; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiChatCompletionTaskSettingsTests extends InferenceSettingsTestCase { + + public void testUpdatedTaskSettings_updatesTaskSettingsWhenDifferent() { + var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123)); + int updatedThinkingBudget = 456; + Map newSettingsMap = new HashMap<>( + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, updatedThinkingBudget))) + ); + + GoogleVertexAiChatCompletionTaskSettings updatedSettings = (GoogleVertexAiChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(updatedThinkingBudget)); + } + + public void testUpdatedTaskSettings_doesNotUpdateTaskSettingsWhenNewSettingsAreEmpty() { + var initialSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(123)); + Map emptySettingsMap = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>())); + + GoogleVertexAiChatCompletionTaskSettings updatedSettings = (GoogleVertexAiChatCompletionTaskSettings) initialSettings + .updatedTaskSettings(emptySettingsMap); + assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(initialSettings.thinkingConfig().getThinkingBudget())); + } + + public void testFromMap_returnsSettings() { + int thinkingBudget = 256; + Map settings = new HashMap<>( + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, thinkingBudget))) + ); + + var result = GoogleVertexAiChatCompletionTaskSettings.fromMap(settings); + assertThat(result.thinkingConfig().getThinkingBudget(), is(thinkingBudget)); + } + + public void testFromMap_throwsWhenValidationErrorEncountered() { + Map settings = new HashMap<>( + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, "not_an_int"))) + ); + + expectThrows(ValidationException.class, () -> GoogleVertexAiChatCompletionTaskSettings.fromMap(settings)); + } + + public void testOf_overridesOriginalSettings_whenNewSettingsPresent() { + // Confirm we can overwrite empty settings + var originalSettings = new GoogleVertexAiChatCompletionTaskSettings(); + int newThinkingBudget = 123; + var newSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(newThinkingBudget)); + var updatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(originalSettings, newSettings); + + assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(newThinkingBudget)); + + // Confirm we can overwrite existing settings + int secondNewThinkingBudget = 456; + var secondNewSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(secondNewThinkingBudget)); + var secondUpdatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(updatedSettings, secondNewSettings); + + assertThat(secondUpdatedSettings.thinkingConfig().getThinkingBudget(), is(secondNewThinkingBudget)); + } + + public void testOf_doesNotOverrideOriginalSettings_whenNewSettingsNotPresent() { + int originalThinkingBudget = 123; + var originalSettings = new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(originalThinkingBudget)); + var emptySettings = new GoogleVertexAiChatCompletionTaskSettings(); + var updatedSettings = GoogleVertexAiChatCompletionTaskSettings.of(originalSettings, emptySettings); + + assertThat(updatedSettings.thinkingConfig().getThinkingBudget(), is(originalThinkingBudget)); + } + + @Override + protected GoogleVertexAiChatCompletionTaskSettings fromMutableMap(Map mutableMap) { + return GoogleVertexAiChatCompletionTaskSettings.fromMap(mutableMap); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleVertexAiChatCompletionTaskSettings::new; + } + + @Override + protected GoogleVertexAiChatCompletionTaskSettings createTestInstance() { + return new GoogleVertexAiChatCompletionTaskSettings(new ThinkingConfig(randomInt())); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfigTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfigTests.java new file mode 100644 index 0000000000000..76a5809589514 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/ThinkingConfigTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_BUDGET_FIELD; +import static org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig.THINKING_CONFIG_FIELD; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class ThinkingConfigTests extends AbstractBWCWireSerializationTestCase { + + public void testNoArgConstructor_createsEmptyConfig() { + ThinkingConfig thinkingConfig = new ThinkingConfig(); + assertThat(thinkingConfig.isEmpty(), is(true)); + assertThat(thinkingConfig.getThinkingBudget(), is(nullValue())); + } + + public void testNullValueInConstructor_createsEmptyConfig() { + Integer nullInt = null; + ThinkingConfig thinkingConfig = new ThinkingConfig(nullInt); + assertThat(thinkingConfig.isEmpty(), is(true)); + assertThat(thinkingConfig.getThinkingBudget(), is(nullValue())); + } + + public void testNonNullValueInConstructor() { + Integer thinkingBudget = 256; + ThinkingConfig thinkingConfig = new ThinkingConfig(thinkingBudget); + assertThat(thinkingConfig.isEmpty(), is(false)); + assertThat(thinkingConfig.getThinkingBudget(), is(thinkingBudget)); + } + + public void testFromMap_withThinkingConfigSpecified_andThinkingBudgetSpecified() { + ValidationException exception = new ValidationException(); + int thinkingBudget = 256; + Map settings = new HashMap<>( + Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, thinkingBudget))) + ); + + ThinkingConfig result = ThinkingConfig.fromMap(settings, exception); + + assertThat(result, is(new ThinkingConfig(thinkingBudget))); + assertThat(exception.validationErrors(), is(empty())); + } + + public void testFromMap_returnsEmptyThinkingConfig_withThinkingConfigSpecified_andThinkingBudgetNotSpecified() { + ValidationException exception = new ValidationException(); + Map settings = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>())); + + ThinkingConfig result = ThinkingConfig.fromMap(settings, exception); + + assertThat(result.isEmpty(), is(true)); + assertThat(exception.validationErrors(), is(empty())); + } + + public void testFromMap_returnsEmptyThinkingConfig_withThinkingConfigNotSpecified_andThinkingBudgetSpecified() { + ValidationException exception = new ValidationException(); + int thinkingBudget = 256; + Map settings = new HashMap<>( + Map.of("not_thinking_config", new HashMap<>(Map.of(THINKING_BUDGET_FIELD, thinkingBudget))) + ); + + ThinkingConfig result = ThinkingConfig.fromMap(settings, exception); + + assertThat(result.isEmpty(), is(true)); + assertThat(exception.validationErrors(), is(empty())); + } + + public void testFromMap_returnsEmptyThinkingConfig_withUnknownField() { + ValidationException exception = new ValidationException(); + int anInt = 42; + Map settings = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of("not_thinking_budget", anInt)))); + + ThinkingConfig result = ThinkingConfig.fromMap(settings, exception); + + assertThat(result.isEmpty(), is(true)); + assertThat(exception.validationErrors(), is(empty())); + } + + public void testFromMap_returnsEmptyThinkingConfig_addsException_whenFieldIsNotInteger() { + ValidationException exception = new ValidationException(); + String notAnInt = "not_an_int"; + Map settings = new HashMap<>(Map.of(THINKING_CONFIG_FIELD, new HashMap<>(Map.of(THINKING_BUDGET_FIELD, notAnInt)))); + + ThinkingConfig result = ThinkingConfig.fromMap(settings, exception); + + assertThat(result.isEmpty(), is(true)); + assertThat(exception.validationErrors(), hasSize(1)); + assertThat( + exception.validationErrors().getLast(), + is("field [thinking_budget] is not of the expected type. The value [" + notAnInt + "] cannot be converted to a [Integer]") + ); + } + + public void testToXContent() throws IOException { + ThinkingConfig thinkingConfig = new ThinkingConfig(256); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + thinkingConfig.toXContent(builder, null); + builder.endObject(); + String xContentResult = Strings.toString(builder); + + String expected = XContentHelper.stripWhitespace(""" + { + "thinking_config": { + "thinking_budget": 256 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return ThinkingConfig::new; + } + + @Override + protected ThinkingConfig createTestInstance() { + return new ThinkingConfig(256); + } + + @Override + protected ThinkingConfig mutateInstance(ThinkingConfig instance) throws IOException { + Integer originalThinkingBudget = instance.getThinkingBudget(); + Integer newThinkingBudget = randomValueOtherThan(originalThinkingBudget, ESTestCase::randomInt); + return new ThinkingConfig(newThinkingBudget); + } + + @Override + protected ThinkingConfig mutateInstanceForVersion(ThinkingConfig instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java index 261a6c2153b04..d33fba0c31806 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; import java.io.IOException; import java.util.ArrayList; @@ -30,6 +31,8 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntityTests extends ESTes private static final String USER_ROLE = "user"; private static final String ASSISTANT_ROLE = "assistant"; + private static final ThinkingConfig thinkingConfig = new ThinkingConfig(256); + private static final ThinkingConfig emptyThinkingConfig = new ThinkingConfig(); public void testBasicSerialization_SingleMessage() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( @@ -43,7 +46,10 @@ public void testBasicSerialization_SingleMessage() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); // stream doesn't affect VertexAI request body - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -86,7 +92,10 @@ public void testSerialization_MultipleMessages() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messages); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -143,7 +152,10 @@ public void testSerialization_Tools() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -209,7 +221,10 @@ public void testSerialization_ToolsChoice() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, false); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -265,7 +280,10 @@ public void testSerialization_WithAllGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + thinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -283,7 +301,10 @@ public void testSerialization_WithAllGenerationConfig() throws IOException { "stopSequences": ["stop1", "stop2"], "temperature": 0.5, "maxOutputTokens": 100, - "topP": 0.9 + "topP": 0.9, + "thinkingConfig": { + "thinkingBudget": 256 + } } } """; @@ -310,7 +331,10 @@ public void testSerialization_WithSomeGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithGenerationConfig, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -333,6 +357,46 @@ public void testSerialization_WithSomeGenerationConfig() throws IOException { assertJsonEquals(jsonString, expectedJson); } + public void testSerialization_WithOnlyThinkingConfig() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Partial config."), + USER_ROLE, + null, + null + ); + + // No generation config fields set on unifiedRequest + var completionRequestWithNoGenerationConfig = UnifiedCompletionRequest.of(List.of(message)); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(completionRequestWithNoGenerationConfig, true); + + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + thinkingConfig + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "contents": [ + { + "role": "user", + "parts": [ { "text": "Partial config." } ] + } + ], + "generationConfig": { + "thinkingConfig": { + "thinkingBudget": 256 + } + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + public void testSerialization_NoGenerationConfig() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("No extra config."), @@ -345,7 +409,10 @@ public void testSerialization_NoGenerationConfig() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -381,7 +448,10 @@ public void testSerialization_WithContentObjects() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(messageList); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -414,7 +484,10 @@ public void testError_UnsupportedRole() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); var builder = JsonXContent.contentBuilder(); var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); @@ -435,7 +508,10 @@ public void testError_UnsupportedContentObjectType() throws IOException { var unifiedRequest = UnifiedCompletionRequest.of(List.of(message)); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, false); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); var builder = JsonXContent.contentBuilder(); var statusException = assertThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); @@ -471,7 +547,10 @@ public void testParseAllFields() throws IOException { ], "temperature": 0.1, "maxOutputTokens": 100, - "topP": 0.2 + "topP": 0.2, + "thinkingConfig": { + "thinkingBudget": 256 + } }, "tools": [ { @@ -535,7 +614,10 @@ public void testParseAllFields() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + thinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -589,7 +671,10 @@ public void testParseFunctionCallNoContent() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -629,7 +714,8 @@ public void testParseFunctionCallWithBadJson() throws IOException { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(requestContentObject, true); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput + unifiedChatInput, + emptyThinkingConfig ); XContentBuilder builder = JsonXContent.contentBuilder(); @@ -711,7 +797,8 @@ public void testParseFunctionCallWithEmptyStringContent() throws IOException { for (var request : requests) { UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( - unifiedChatInput + unifiedChatInput, + emptyThinkingConfig ); XContentBuilder builder = JsonXContent.contentBuilder(); @@ -755,7 +842,10 @@ public void testParseToolChoiceString() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -819,7 +909,10 @@ public void testBuildSystemMessage_MultipleParts() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -874,7 +967,10 @@ public void testBuildSystemMessageMul() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -903,7 +999,10 @@ public void testParseToolChoiceInvalid_throwElasticSearchStatusException() throw ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); var statusException = expectThrows(ElasticsearchStatusException.class, () -> entity.toXContent(builder, ToXContent.EMPTY_PARAMS)); @@ -987,7 +1086,10 @@ public void testParseMultipleTools() throws IOException { ); UnifiedChatInput unifiedChatInput = new UnifiedChatInput(request, true); - GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); + GoogleVertexAiUnifiedChatCompletionRequestEntity entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity( + unifiedChatInput, + emptyThinkingConfig + ); XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java index 91c3eae4b72a6..aa4dff03a962a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; @@ -40,7 +41,7 @@ public void testCreateRequest_Default() throws IOException { var messages = List.of("Hello Gemini!"); - var request = createRequest(projectId, location, modelId, messages, null, null); + var request = createRequest(projectId, location, modelId, messages, null, null, null); var httpRequest = request.createHttpRequest(); var httpPost = (HttpPost) httpRequest.httpRequestBase(); @@ -67,27 +68,22 @@ public void testCreateRequest_Default() throws IOException { } - public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( - UnifiedChatInput input, - GoogleVertexAiChatCompletionModel model - ) { - return new GoogleVertexAiUnifiedChatCompletionWithoutAuthRequest(input, model); - } - public static GoogleVertexAiUnifiedChatCompletionRequest createRequest( String projectId, String location, String modelId, List messages, @Nullable String apiKey, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + @Nullable ThinkingConfig thinkingConfig ) { var model = GoogleVertexAiChatCompletionModelTests.createCompletionModel( projectId, location, modelId, Objects.requireNonNullElse(apiKey, "default-api-key"), - Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)) + Objects.requireNonNullElse(rateLimitSettings, new RateLimitSettings(100)), + thinkingConfig ); var unifiedChatInput = new UnifiedChatInput(messages, "user", true);