Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133599.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133599
summary: Support Gemini thinking budget in inference API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00);
public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00);
public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00);
public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_153_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
Expand Down Expand Up @@ -570,6 +571,14 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
GoogleVertexAiChatCompletionTaskSettings.NAME,
GoogleVertexAiChatCompletionTaskSettings::new
)
);

}

private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,25 @@ public static Integer extractOptionalPositiveInteger(
String settingName,
String scope,
ValidationException validationException
) {
return extractOptionalInteger(map, settingName, scope, validationException, true);
}

/**
 * Extracts an optional integer setting from {@code map} without requiring the value to be positive.
 * Delegates to the private overload with the positivity check switched off.
 *
 * @param map                 the settings map to read {@code settingName} from
 * @param settingName         the key of the setting to extract
 * @param scope               the scope passed through to the shared extraction helper
 * @param validationException accumulator passed through to the shared extraction helper
 * @return whatever the shared extraction helper returns for this setting
 */
public static Integer extractOptionalInteger(
    Map<String, Object> map,
    String settingName,
    String scope,
    ValidationException validationException
) {
    // Zero and negative values are acceptable for callers of this overload.
    final boolean requirePositive = false;
    return extractOptionalInteger(map, settingName, scope, validationException, requirePositive);
}

private static Integer extractOptionalInteger(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException,
boolean mustBePositive
) {
int initialValidationErrorCount = validationException.validationErrors().size();
Integer optionalField = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException);
Expand All @@ -807,7 +826,7 @@ public static Integer extractOptionalPositiveInteger(
return null;
}

if (optionalField != null && optionalField <= 0) {
if (optionalField != null && mustBePositive && optionalField <= 0) {
validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField));
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,14 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje

@Override
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = GoogleVertexAiChatCompletionModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);

var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
overriddenModel,
CHAT_COMPLETION_HANDLER,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), overriddenModel),
ChatCompletionInput.class
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
Expand Down Expand Up @@ -47,7 +46,7 @@ public GoogleVertexAiChatCompletionModel(
taskType,
service,
GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
new EmptyTaskSettings(),
GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings),
GoogleVertexAiSecretSettings.fromMap(secrets)
);
}
Expand All @@ -57,7 +56,7 @@ public GoogleVertexAiChatCompletionModel(
TaskType taskType,
String service,
GoogleVertexAiChatCompletionServiceSettings serviceSettings,
EmptyTaskSettings taskSettings,
GoogleVertexAiChatCompletionTaskSettings taskSettings,
@Nullable GoogleVertexAiSecretSettings secrets
) {
super(
Expand All @@ -73,6 +72,14 @@ public GoogleVertexAiChatCompletionModel(
}
}

/**
 * Creates a model from an existing {@code model} with {@code taskSettings} passed to the superclass constructor,
 * reusing the streaming URI of the existing model.
 *
 * @param model        the model to copy from
 * @param taskSettings the task settings handed to the superclass constructor
 */
private GoogleVertexAiChatCompletionModel(
    GoogleVertexAiChatCompletionModel model,
    GoogleVertexAiChatCompletionTaskSettings taskSettings
) {
    super(model, taskSettings);
    // The streaming URI is taken from the source model as-is.
    this.streamingURI = model.streamingURI();
}

public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) {
var originalModelServiceSettings = model.getServiceSettings();

Expand All @@ -93,6 +100,26 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM
);
}

/**
 * Returns a model whose task settings are overridden by the entries of {@code taskSettingsMap}.
 * The original {@code model} instance is returned unchanged when the map is {@code null} or empty, when the parsed
 * settings are empty, or when the parsed settings equal the settings the model already has.
 *
 * @param model the model whose task settings will be overridden
 * @param taskSettingsMap the new task settings to use
 * @return a {@link GoogleVertexAiChatCompletionModel} with overridden {@link GoogleVertexAiChatCompletionTaskSettings}
 */
public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettingsMap) {
    final boolean nothingSupplied = taskSettingsMap == null || taskSettingsMap.isEmpty();
    if (nothingSupplied) {
        return model;
    }

    final GoogleVertexAiChatCompletionTaskSettings overrides = GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettingsMap);
    if (overrides.isEmpty()) {
        // The map carried no setting that this task type understands.
        return model;
    }

    final GoogleVertexAiChatCompletionTaskSettings current = model.getTaskSettings();
    if (current.equals(overrides)) {
        // Same settings as before, so no new model is needed.
        return model;
    }

    return new GoogleVertexAiChatCompletionModel(model, GoogleVertexAiChatCompletionTaskSettings.of(current, overrides));
}

@Override
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
return visitor.create(this, taskSettings);
Expand All @@ -109,8 +136,8 @@ public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() {
}

@Override
public EmptyTaskSettings getTaskSettings() {
return (EmptyTaskSettings) super.getTaskSettings();
public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() {
return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -116,6 +117,11 @@ public String getWriteableName() {
return NAME;
}

/**
 * Returns the rate limit settings held by these service settings.
 */
@Override
public RateLimitSettings rateLimitSettings() {
    return this.rateLimitSettings;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
assert false : "should never be called when supportsVersion is used";
Expand Down Expand Up @@ -161,7 +167,7 @@ public int hashCode() {
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
public String toString() {
return Strings.toString(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.completion;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Task settings for Google Vertex AI chat completion. The only setting carried is a {@link ThinkingConfig}.
 * <p>
 * The {@code thinkingConfig} field is final and is never {@code null}: when no configuration is supplied the shared
 * {@code EMPTY_THINKING_CONFIG} instance is stored instead.
 */
public class GoogleVertexAiChatCompletionTaskSettings implements TaskSettings {
    public static final String NAME = "google_vertex_ai_chatcompletion_task_settings";

    // Must be declared BEFORE EMPTY_SETTINGS. Static initializers run in textual order and the no-arg constructor reads
    // this constant; with the opposite order EMPTY_SETTINGS would capture null and fail on first use.
    private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();

    /** Shared instance that carries no thinking configuration. */
    public static final GoogleVertexAiChatCompletionTaskSettings EMPTY_SETTINGS = new GoogleVertexAiChatCompletionTaskSettings();

    private final ThinkingConfig thinkingConfig;

    /** Creates settings that hold the shared empty {@link ThinkingConfig}. */
    public GoogleVertexAiChatCompletionTaskSettings() {
        thinkingConfig = EMPTY_THINKING_CONFIG;
    }

    /**
     * Creates settings around the given configuration.
     *
     * @param thinkingConfig the configuration to hold; {@code null} is replaced by the shared empty configuration
     */
    public GoogleVertexAiChatCompletionTaskSettings(ThinkingConfig thinkingConfig) {
        this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
    }

    /**
     * Reads settings from a stream; the counterpart of {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read the {@link ThinkingConfig} from
     * @throws IOException if reading from the stream fails
     */
    public GoogleVertexAiChatCompletionTaskSettings(StreamInput in) throws IOException {
        thinkingConfig = new ThinkingConfig(in);
    }

    /**
     * Parses task settings from a map.
     *
     * @param taskSettings the map to parse; it is handed directly to {@code ThinkingConfig.fromMap}
     * @return the parsed settings
     * @throws ValidationException if parsing recorded at least one validation error
     */
    public static GoogleVertexAiChatCompletionTaskSettings fromMap(Map<String, Object> taskSettings) {
        ValidationException validationException = new ValidationException();

        // Extract optional thinkingConfig settings
        ThinkingConfig thinkingConfig = ThinkingConfig.fromMap(taskSettings, validationException);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
    }

    /**
     * Combines two settings objects. The thinking configuration of {@code newTaskSettings} wins unless it is empty, in
     * which case the configuration of {@code originalTaskSettings} is kept.
     *
     * @param originalTaskSettings the settings to fall back to
     * @param newTaskSettings      the settings that take precedence when non-empty
     * @return a new settings instance holding the selected configuration
     */
    public static GoogleVertexAiChatCompletionTaskSettings of(
        GoogleVertexAiChatCompletionTaskSettings originalTaskSettings,
        GoogleVertexAiChatCompletionTaskSettings newTaskSettings
    ) {
        ThinkingConfig thinkingConfig = newTaskSettings.thinkingConfig().isEmpty()
            ? originalTaskSettings.thinkingConfig()
            : newTaskSettings.thinkingConfig();
        return new GoogleVertexAiChatCompletionTaskSettings(thinkingConfig);
    }

    /** @return the thinking configuration held by these settings */
    public ThinkingConfig thinkingConfig() {
        return thinkingConfig;
    }

    /** @return {@code true} when the held thinking configuration reports itself as empty */
    @Override
    public boolean isEmpty() {
        return thinkingConfig.isEmpty();
    }

    /**
     * Parses {@code newSettings} and combines the result with these settings via
     * {@link #of(GoogleVertexAiChatCompletionTaskSettings, GoogleVertexAiChatCompletionTaskSettings)}.
     *
     * @param newSettings the settings to apply on top of the current ones
     * @return the combined settings
     */
    @Override
    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        // Parse a copy rather than the caller's map (presumably because parsing removes entries - confirm in ThinkingConfig.fromMap).
        GoogleVertexAiChatCompletionTaskSettings newTaskSettings = GoogleVertexAiChatCompletionTaskSettings.fromMap(
            new HashMap<>(newSettings)
        );
        return GoogleVertexAiChatCompletionTaskSettings.of(this, newTaskSettings);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.GEMINI_THINKING_BUDGET_ADDED;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        thinkingConfig.writeTo(out);
    }

    /** Writes the thinking configuration inside a single enclosing object. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        thinkingConfig.toXContent(builder, params);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        GoogleVertexAiChatCompletionTaskSettings that = (GoogleVertexAiChatCompletionTaskSettings) o;
        return Objects.equals(thinkingConfig, that.thinkingConfig);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(thinkingConfig);
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }
}
Loading