Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SCRIPT_RESCORER = def(9_143_0_00);
public static final TransportVersion ESQL_LOOKUP_OPERATOR_EMITTED_ROWS = def(9_144_0_00);
public static final TransportVersion ALLOCATION_DECISION_NOT_PREFERRED = def(9_145_0_00);
public static final TransportVersion GEMINI_THINKING_BUDGET_ADDED = def(9_146_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,25 @@ public static Integer extractOptionalPositiveInteger(
String settingName,
String scope,
ValidationException validationException
) {
return extractOptionalInteger(map, settingName, scope, validationException, true);
}

/**
 * Extracts an optional {@link Integer} setting from {@code map} without requiring the value to be positive, so zero and
 * negative values pass through unchanged. Callers that must reject non-positive values should use
 * {@code extractOptionalPositiveInteger} instead.
 *
 * @param map                 the settings map the value is removed from
 * @param settingName         the key of the setting to extract
 * @param scope               the settings scope; forwarded to the shared implementation
 * @param validationException collector for validation errors raised while extracting the value
 * @return whatever the shared private overload returns for a non-positive-tolerant extraction; may be {@code null}
 */
public static Integer extractOptionalInteger(
    Map<String, Object> map,
    String settingName,
    String scope,
    ValidationException validationException
) {
    // This overload differs from the "positive" variant only by switching the positivity check off.
    final boolean mustBePositive = false;
    return extractOptionalInteger(map, settingName, scope, validationException, mustBePositive);
}

private static Integer extractOptionalInteger(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException,
boolean mustBePositive
) {
int initialValidationErrorCount = validationException.validationErrors().size();
Integer optionalField = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException);
Expand All @@ -788,7 +807,7 @@ public static Integer extractOptionalPositiveInteger(
return null;
}

if (optionalField != null && optionalField <= 0) {
if (optionalField != null && mustBePositive && optionalField <= 0) {
validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField));
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionM
originalModelServiceSettings.projectId(),
originalModelServiceSettings.location(),
Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()),
originalModelServiceSettings.rateLimitSettings()
originalModelServiceSettings.rateLimitSettings(),
originalModelServiceSettings.thinkingConfig()
);

return new GoogleVertexAiChatCompletionModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -44,12 +45,14 @@ public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXConten
private final String projectId;

private final RateLimitSettings rateLimitSettings;
private final ThinkingConfig thinkingConfig;

// https://cloud.google.com/vertex-ai/docs/quotas#eval-quotas
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1000);
private static final ThinkingConfig EMPTY_THINKING_CONFIG = new ThinkingConfig();

public GoogleVertexAiChatCompletionServiceSettings(StreamInput in) throws IOException {
this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in));
this(in.readString(), in.readString(), in.readString(), new RateLimitSettings(in), new ThinkingConfig(in));
}

@Override
Expand All @@ -58,6 +61,7 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
builder.field(LOCATION, location);
builder.field(MODEL_ID, modelId);
rateLimitSettings.toXContent(builder, params);
thinkingConfig.toXContent(builder, params);
return builder;
}

Expand All @@ -78,23 +82,34 @@ public static GoogleVertexAiChatCompletionServiceSettings fromMap(Map<String, Ob
context
);

// Extract optional thinkingConfig settings
ThinkingConfig thinkingConfig = ThinkingConfig.of(
map,
EMPTY_THINKING_CONFIG,
validationException,
GoogleVertexAiService.NAME,
context
);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings);
return new GoogleVertexAiChatCompletionServiceSettings(projectId, location, modelId, rateLimitSettings, thinkingConfig);
}

public GoogleVertexAiChatCompletionServiceSettings(
String projectId,
String location,
String modelId,
@Nullable RateLimitSettings rateLimitSettings
@Nullable RateLimitSettings rateLimitSettings,
@Nullable ThinkingConfig thinkingConfig
) {
this.projectId = projectId;
this.location = location;
this.modelId = modelId;
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.thinkingConfig = Objects.requireNonNullElse(thinkingConfig, EMPTY_THINKING_CONFIG);
}

public String location() {
Expand All @@ -116,6 +131,15 @@ public String getWriteableName() {
return NAME;
}

/**
 * Returns the rate limit settings held by these service settings. The constructor substitutes
 * {@code DEFAULT_RATE_LIMIT_SETTINGS} when none are supplied, so the result is never {@code null}.
 */
@Override
public RateLimitSettings rateLimitSettings() {
    return this.rateLimitSettings;
}

/**
 * Returns the thinking configuration held by these service settings. The constructor substitutes
 * {@code EMPTY_THINKING_CONFIG} when none is supplied, so the result is never {@code null}.
 */
public ThinkingConfig thinkingConfig() {
    return this.thinkingConfig;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
assert false : "should never be called when supportsVersion is used";
Expand All @@ -134,6 +158,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(location);
out.writeString(modelId);
rateLimitSettings.writeTo(out);
thinkingConfig.writeTo(out);
}

@Override
Expand All @@ -152,16 +177,17 @@ public boolean equals(Object o) {
return Objects.equals(location, that.location)
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(projectId, that.projectId)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(thinkingConfig, that.thinkingConfig);
}

@Override
public int hashCode() {
return Objects.hash(location, modelId, projectId, rateLimitSettings);
return Objects.hash(location, modelId, projectId, rateLimitSettings, thinkingConfig);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
public String toString() {
return Strings.toString(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.completion;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;

/**
 * Models the ThinkingConfig object that lives inside GenerationConfig. At present only the thinkingBudget field is handled;
 * the includeThoughts field may be added in the future.
 * <p>
 * An instance without a thinking budget is treated as empty: it reports {@link #isEmpty()} as {@code true} and writes
 * nothing in {@link #toXContent(XContentBuilder, Params)}.
 */
public class ThinkingConfig implements Writeable, ToXContentFragment {
    public static final String THINKING_CONFIG_FIELD = "thinking_config";
    public static final String THINKING_BUDGET_FIELD = "thinking_budget";

    // null means "no budget configured", i.e. an empty config
    private final Integer thinkingBudget;

    /**
     * Creates an empty {@code ThinkingConfig}, i.e. one without a thinking budget.
     */
    public ThinkingConfig() {
        // The cast selects the Integer constructor; a bare null would be ambiguous with the StreamInput one.
        this((Integer) null);
    }

    /**
     * Creates a {@code ThinkingConfig} holding the given budget.
     *
     * @param thinkingBudget the thinking budget, or {@code null} for an empty config
     */
    public ThinkingConfig(Integer thinkingBudget) {
        this.thinkingBudget = thinkingBudget;
    }

    /**
     * Deserializes a {@code ThinkingConfig}. When the stream's transport version predates
     * {@link TransportVersions#GEMINI_THINKING_BUDGET_ADDED} nothing is read and the resulting config is empty.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public ThinkingConfig(StreamInput in) throws IOException {
        boolean budgetOnWire = in.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED);
        this.thinkingBudget = budgetOnWire ? in.readOptionalVInt() : null;
    }

    /**
     * Builds a {@code ThinkingConfig} from the {@value #THINKING_CONFIG_FIELD} entry of a settings map. The entry is removed
     * from {@code map} and its {@value #THINKING_BUDGET_FIELD} value is extracted as an optional integer that is not required
     * to be positive.
     *
     * @param map                 the settings map to consume the thinking config entry from
     * @param defaultValue        returned when no thinking budget was extracted
     * @param validationException collector for validation errors raised while extracting the budget
     * @param serviceName         service name handed to {@code throwIfNotEmptyMap} for leftover entries
     * @param context             parse context; leftover entries are only checked in a request context
     * @return a config holding the extracted budget, or {@code defaultValue} if none was extracted
     */
    public static ThinkingConfig of(
        Map<String, Object> map,
        ThinkingConfig defaultValue,
        ValidationException validationException,
        String serviceName,
        ConfigurationParseContext context
    ) {
        Map<String, Object> nestedSettings = removeFromMapOrDefaultEmpty(map, THINKING_CONFIG_FIELD);
        Integer parsedBudget = ServiceUtils.extractOptionalInteger(
            nestedSettings,
            THINKING_BUDGET_FIELD,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );

        // Anything still left in the nested map was not consumed above; only request contexts check for that.
        if (ConfigurationParseContext.isRequestContext(context)) {
            throwIfNotEmptyMap(nestedSettings, serviceName);
        }

        if (parsedBudget == null) {
            return defaultValue;
        }
        return new ThinkingConfig(parsedBudget);
    }

    /**
     * @return {@code true} when no thinking budget is set
     */
    public boolean isEmpty() {
        return thinkingBudget == null;
    }

    /**
     * @return the thinking budget, or {@code null} when this config is empty
     */
    public Integer getThinkingBudget() {
        return thinkingBudget;
    }

    /**
     * Serializes the thinking budget, but only to streams whose transport version is on or after
     * {@link TransportVersions#GEMINI_THINKING_BUDGET_ADDED}; for older versions nothing is written.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        boolean budgetOnWire = out.getTransportVersion().onOrAfter(TransportVersions.GEMINI_THINKING_BUDGET_ADDED);
        if (budgetOnWire) {
            out.writeOptionalVInt(thinkingBudget);
        }
    }

    /**
     * Writes a {@value #THINKING_CONFIG_FIELD} object containing the budget, or nothing at all when this config is empty.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        if (thinkingBudget == null) {
            return builder;
        }

        builder.startObject(THINKING_CONFIG_FIELD);
        builder.field(THINKING_BUDGET_FIELD, thinkingBudget);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        ThinkingConfig other = (ThinkingConfig) o;
        return Objects.equals(this.thinkingBudget, other.thinkingBudget);
    }

    @Override
    public int hashCode() {
        // Same value as Objects.hashCode(thinkingBudget): 0 for null, otherwise the Integer's own hash.
        return thinkingBudget == null ? 0 : thinkingBudget.hashCode();
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatIn
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(uri);

var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput);
var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(
unifiedChatInput,
model.getServiceSettings().thinkingConfig()
);

ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.ThinkingConfig;

import java.io.IOException;
import java.util.Locale;
Expand All @@ -37,6 +38,8 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
private static final String TEMPERATURE = "temperature";
private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens";
private static final String TOP_P = "topP";
private static final String THINKING_CONFIG = "thinkingConfig";
private static final String THINKING_BUDGET = "thinkingBudget";

private static final String TOOLS = "tools";
private static final String FUNCTION_DECLARATIONS = "functionDeclarations";
Expand All @@ -56,6 +59,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
private static final String FUNCTION_CALL_ARGS = "args";

private final UnifiedChatInput unifiedChatInput;
private final ThinkingConfig thinkingConfig;

private static final String USER_ROLE = "user";
private static final String MODEL_ROLE = "model";
Expand All @@ -66,8 +70,9 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont

private static final String SYSTEM_INSTRUCTION = "systemInstruction";

public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, ThinkingConfig thinkingConfig) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.thinkingConfig = Objects.requireNonNull(thinkingConfig);
}

private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) {
Expand Down Expand Up @@ -316,7 +321,8 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException {
boolean hasAnyConfig = request.stop() != null
|| request.temperature() != null
|| request.maxCompletionTokens() != null
|| request.topP() != null;
|| request.topP() != null
|| thinkingConfig.isEmpty() == false;

if (hasAnyConfig == false) {
return;
Expand All @@ -336,6 +342,11 @@ private void buildGenerationConfig(XContentBuilder builder) throws IOException {
if (request.topP() != null) {
builder.field(TOP_P, request.topP());
}
if (thinkingConfig.isEmpty() == false) {
builder.startObject(THINKING_CONFIG);
builder.field(THINKING_BUDGET, thinkingConfig.getThinkingBudget());
builder.endObject();
}

builder.endObject();
}
Expand Down
Loading