-
Notifications
You must be signed in to change notification settings - Fork 1
Vertexai chatcompletion #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
d429a31
00bfdb0
2378270
1f00974
970ab3c
5428074
ee44f22
8160c2b
ff68fbe
29c7093
bfd75b0
2ebfac9
679ea80
e611cc3
87e428a
193d06d
813a2e8
f1ab8cc
23c7d92
c45d23f
a820d83
d2f09cf
bda94de
5dee072
2f75788
b50c911
d6ae90f
cbb387f
7e1c970
5ab716f
28aa464
2279391
85af5c0
16c01b0
1732244
6cc165b
c020122
7821d58
8633659
06020cc
5a2cfe5
0cf1f3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.completion; | ||
|
|
||
| import org.apache.http.client.utils.URIBuilder; | ||
| import org.elasticsearch.core.Nullable; | ||
| import org.elasticsearch.inference.ModelConfigurations; | ||
| import org.elasticsearch.inference.ModelSecrets; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.inference.UnifiedCompletionRequest; | ||
| import org.elasticsearch.xpack.inference.external.action.ExecutableAction; | ||
| import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings; | ||
|
|
||
| import java.net.URISyntaxException; | ||
| import java.util.Map; | ||
| import java.net.URI; | ||
| import java.util.Objects; | ||
|
|
||
| import static org.elasticsearch.core.Strings.format; | ||
|
|
||
| public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel { | ||
| public GoogleVertexAiChatCompletionModel( | ||
| String inferenceEntityId, | ||
| TaskType taskType, | ||
| String service, | ||
| Map<String, Object> serviceSettings, | ||
| Map<String, Object> taskSettings, | ||
| Map<String, Object> secrets, | ||
| ConfigurationParseContext context | ||
| ) { | ||
| this( | ||
| inferenceEntityId, | ||
| taskType, | ||
| service, | ||
| GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context), | ||
| GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings), | ||
| GoogleVertexAiSecretSettings.fromMap(secrets) | ||
| ); | ||
| } | ||
|
|
||
| GoogleVertexAiChatCompletionModel( | ||
| String inferenceEntityId, | ||
| TaskType taskType, | ||
| String service, | ||
| GoogleVertexAiChatCompletionServiceSettings serviceSettings, | ||
| GoogleVertexAiChatCompletionTaskSettings taskSettings, | ||
| @Nullable GoogleVertexAiSecretSettings secrets | ||
| ) { | ||
| super( | ||
| new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), | ||
| new ModelSecrets(secrets), | ||
| serviceSettings | ||
| ); | ||
| try { | ||
| this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); | ||
| } catch (URISyntaxException e) { | ||
| throw new RuntimeException(e); | ||
| } | ||
| } | ||
|
|
||
| public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) { | ||
| var originalModelServiceSettings = model.getServiceSettings(); | ||
|
|
||
| var newServiceSettings = new GoogleVertexAiChatCompletionServiceSettings( | ||
| originalModelServiceSettings.projectId(), | ||
| originalModelServiceSettings.location(), | ||
| Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), | ||
| originalModelServiceSettings.rateLimitSettings() | ||
| ); | ||
|
|
||
| return new GoogleVertexAiChatCompletionModel( | ||
| model.getInferenceEntityId(), | ||
| model.getTaskType(), | ||
| model.getConfigurations().getService(), | ||
| newServiceSettings, | ||
| model.getTaskSettings(), | ||
| model.getSecretSettings() | ||
| ); | ||
| } | ||
|
|
||
| public GoogleVertexAiChatCompletionModel( | ||
| ModelConfigurations configurations, | ||
| ModelSecrets secrets, | ||
| GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings | ||
| ) { | ||
| super(configurations, secrets, rateLimitServiceSettings); | ||
| } | ||
|
|
||
| @Override | ||
| public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) { | ||
| return visitor.create(this, taskSettings); | ||
| } | ||
|
|
||
| @Override | ||
| public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() { | ||
| return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings(); | ||
| } | ||
|
|
||
| @Override | ||
| public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() { | ||
| return (GoogleVertexAiChatCompletionServiceSettings) super.getServiceSettings(); | ||
| } | ||
|
|
||
| @Override | ||
| public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() { | ||
| return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings(); | ||
| } | ||
|
|
||
| @Override | ||
| public GoogleVertexAiSecretSettings getSecretSettings() { | ||
| return (GoogleVertexAiSecretSettings) super.getSecretSettings(); | ||
| } | ||
|
|
||
| public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { | ||
| return new URIBuilder().setScheme("https") | ||
| .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) | ||
| .setPathSegments( | ||
| GoogleVertexAiUtils.V1, | ||
| GoogleVertexAiUtils.PROJECTS, | ||
| projectId, | ||
| GoogleVertexAiUtils.LOCATIONS, | ||
| GoogleVertexAiUtils.GLOBAL, | ||
| GoogleVertexAiUtils.PUBLISHERS, | ||
| GoogleVertexAiUtils.PUBLISHER_GOOGLE, | ||
| GoogleVertexAiUtils.MODELS, | ||
| format("%s:%s", model, GoogleVertexAiUtils.STREAM_GENERATE_CONTENT) | ||
| ) | ||
| .build(); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.request; | ||
|
|
||
| import org.apache.http.HttpHeaders; | ||
| import org.apache.http.client.methods.HttpPost; | ||
| import org.apache.http.entity.ByteArrayEntity; | ||
| import org.elasticsearch.common.Strings; | ||
| import org.elasticsearch.xcontent.XContentType; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.external.request.HttpRequest; | ||
| import org.elasticsearch.xpack.inference.external.request.Request; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; | ||
|
|
||
| import java.net.URI; | ||
| import java.nio.charset.StandardCharsets; | ||
| import java.util.Objects; | ||
|
|
||
| public class GoogleVertexAiUnifiedChatCompletionRequest implements GoogleVertexAiRequest { | ||
|
|
||
| private final GoogleVertexAiChatCompletionModel model; | ||
| private final UnifiedChatInput unifiedChatInput; | ||
|
|
||
| public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { | ||
| this.model = Objects.requireNonNull(model); | ||
| this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); | ||
| } | ||
|
|
||
| @Override | ||
| public HttpRequest createHttpRequest() { | ||
| HttpPost httpPost = new HttpPost(model.uri()); | ||
|
|
||
| var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); | ||
|
|
||
| ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); | ||
| httpPost.setEntity(byteEntity); | ||
|
|
||
| httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); | ||
|
|
||
| decorateWithAuth(httpPost); | ||
| return new HttpRequest(httpPost, getInferenceEntityId()); | ||
| } | ||
|
|
||
| public void decorateWithAuth(HttpPost httpPost) { | ||
| GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings()); | ||
| } | ||
|
|
||
| @Override | ||
| public URI getURI() { | ||
| return model.uri(); | ||
| } | ||
|
|
||
| @Override | ||
| public Request truncate() { | ||
| // No truncation for Google VertexAI Chat completions | ||
| return this; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean[] getTruncationInfo() { | ||
| // No truncation for Google VertexAI Chat completions | ||
| return null; | ||
| } | ||
|
|
||
| @Override | ||
| public String getInferenceEntityId() { | ||
| return model.getInferenceEntityId(); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.services.googlevertexai.request; | ||
|
|
||
| import org.elasticsearch.inference.UnifiedCompletionRequest; | ||
| import org.elasticsearch.xcontent.ToXContentObject; | ||
| import org.elasticsearch.xcontent.XContentBuilder; | ||
| import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
| import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Objects; | ||
|
|
||
| import static org.elasticsearch.core.Strings.format; | ||
|
|
||
| public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXContentObject { | ||
| // Field names matching the Google Vertex AI API structure | ||
| private static final String CONTENTS = "contents"; | ||
| private static final String ROLE = "role"; | ||
| private static final String PARTS = "parts"; | ||
| private static final String TEXT = "text"; | ||
| private static final String GENERATION_CONFIG = "generationConfig"; | ||
| private static final String TEMPERATURE = "temperature"; | ||
| private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens"; | ||
| private static final String TOP_P = "topP"; | ||
| // TODO: Add other generationConfig fields if needed (e.g., stopSequences, topK) | ||
|
|
||
| private final UnifiedChatInput unifiedChatInput; | ||
| private final GoogleVertexAiChatCompletionModel model; | ||
|
|
||
| private static final String USER_ROLE = "user"; | ||
| private static final String MODEL_ROLE = "model"; | ||
| private static final String STOP_SEQUENCES = "stopSequences"; | ||
|
|
||
| public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { | ||
| this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); | ||
| this.model = Objects.requireNonNull(model); // Keep the model reference | ||
| } | ||
|
|
||
| private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) throws IOException { | ||
| var messageRoleLowered = messageRole.toLowerCase(); | ||
|
|
||
| if (messageRoleLowered.equals(USER_ROLE) || messageRoleLowered.equals(MODEL_ROLE)) { | ||
| return messageRoleLowered; | ||
| } | ||
|
|
||
| // TODO: Here is OK to throw an IOException? | ||
|
||
| throw new IOException( | ||
| format( | ||
| "Role %s not supported by Google VertexAI ChatCompletion. Supported roles: '%s', '%s'", | ||
| messageRole, | ||
| USER_ROLE, | ||
| MODEL_ROLE | ||
| ) | ||
| ); | ||
|
|
||
| } | ||
|
|
||
| private void buildContents(XContentBuilder builder) throws IOException { | ||
| var messages = unifiedChatInput.getRequest().messages(); | ||
|
|
||
| builder.startArray(CONTENTS); | ||
| for (UnifiedCompletionRequest.Message message : messages) { | ||
| builder.startObject(); | ||
| builder.field(ROLE, messageRoleToGoogleVertexAiSupportedRole(message.role())); | ||
| builder.startArray(PARTS); | ||
| builder.startObject(); | ||
| builder.field(TEXT, message.content().toString()); | ||
|
||
| builder.endObject(); | ||
| builder.endArray(); | ||
| builder.endObject(); | ||
| } | ||
| builder.endArray(); | ||
| } | ||
|
|
||
| private void buildGenerationConfig(XContentBuilder builder) throws IOException { | ||
| var request = unifiedChatInput.getRequest(); | ||
|
|
||
| boolean hasAnyConfig = request.stop() != null | ||
| || request.temperature() != null | ||
| || request.maxCompletionTokens() != null | ||
| || request.topP() != null; | ||
|
|
||
| if (hasAnyConfig == false) { | ||
| return; | ||
| } | ||
|
|
||
| builder.startObject(GENERATION_CONFIG); | ||
|
|
||
| if (request.stop() != null) { | ||
| builder.stringListField(STOP_SEQUENCES, request.stop()); | ||
| } | ||
| if (request.temperature() != null) { | ||
| builder.field(TEMPERATURE, request.temperature()); | ||
| } | ||
| if (request.maxCompletionTokens() != null) { | ||
| builder.field(MAX_OUTPUT_TOKENS, request.maxCompletionTokens()); | ||
| } | ||
| if (request.topP() != null) { | ||
| builder.field(TOP_P, request.topP()); | ||
| } | ||
|
|
||
| builder.endObject(); | ||
| } | ||
|
|
||
| @Override | ||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| builder.startObject(); | ||
|
|
||
| buildContents(builder); | ||
| buildGenerationConfig(builder); | ||
|
|
||
| builder.endObject(); | ||
| return builder; | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.