Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d429a31
VertexAI chat completion response entity with tests
lhoet-google Apr 29, 2025
00bfdb0
Modified build gradle to include google vertexai sdk
lhoet-google Apr 29, 2025
2378270
Google vertex ai chat completion model with tests
lhoet-google Apr 29, 2025
1f00974
Google vertex ai chat completion request with tests
lhoet-google Apr 30, 2025
970ab3c
TransportVersion
lhoet-google Apr 30, 2025
5428074
ChatCompletion TaskSettings & ServiceSettings
lhoet-google Apr 30, 2025
ee44f22
ChatCompletionRequestManager & tests
lhoet-google Apr 30, 2025
8160c2b
VertexAI Service and related classes. WIP & missing tests
lhoet-google Apr 30, 2025
ff68fbe
VertexAi ChatCompletion task settings fix.
lhoet-google May 5, 2025
29c7093
JsonArrayParts event processor & parser
lhoet-google May 6, 2025
bfd75b0
AI Service and service tests
lhoet-google May 6, 2025
2ebfac9
Unified chat completion response and request handlers. Also working w…
lhoet-google May 6, 2025
679ea80
StreamingProcessor now support tools. Added more tests
lhoet-google May 8, 2025
e611cc3
More tests for streaming processor
lhoet-google May 8, 2025
87e428a
Request entity tests
lhoet-google May 12, 2025
193d06d
Google vertexai unified chat completion entity now accepting tools an…
lhoet-google May 12, 2025
813a2e8
Serializing function call message
lhoet-google May 12, 2025
f1ab8cc
Response handler with tests
lhoet-google May 12, 2025
23c7d92
VertexAI chat completion req entity bugfixes
lhoet-google May 13, 2025
c45d23f
Bugfix in vertex ai unified chat completion req entity
lhoet-google May 13, 2025
a820d83
Bugfix in vertex ai unified streaming processor
lhoet-google May 13, 2025
d2f09cf
Removed google aiplatform sdk
lhoet-google May 13, 2025
bda94de
Renamed file to match class name for JsonArrayPartsEventParser
lhoet-google May 13, 2025
5dee072
Updated rate limit settings for vertex ai
lhoet-google May 13, 2025
2f75788
Deleted GoogleVertexAiChatCompletionTaskSettings
lhoet-google May 13, 2025
b50c911
VertexAI Unified chat completion request tests
lhoet-google May 14, 2025
d6ae90f
Fixed some tests
lhoet-google May 14, 2025
cbb387f
Fixed GoogleAIService get configuration tests
lhoet-google May 14, 2025
7e1c970
GoogleVertexAiCompletion action tests
lhoet-google May 14, 2025
5ab716f
Formatting
lhoet-google May 15, 2025
28aa464
Code style fix
lhoet-google May 15, 2025
2279391
Removed unnused variables
lhoet-google May 15, 2025
85af5c0
Function call id fixed
lhoet-google May 15, 2025
16c01b0
Bugfix
lhoet-google May 15, 2025
1732244
Merge branch 'main' into vertexai-chatcompletion
lhoet-google May 16, 2025
6cc165b
Testfix
lhoet-google May 16, 2025
c020122
Unit tests
beltrangs May 16, 2025
7821d58
Merge branch 'vertexai-chatcompletion' into google-chat-completion-tests
beltrangs May 16, 2025
8633659
Update ElasticInferenceServiceTests.java
beltrangs May 16, 2025
06020cc
Update GoogleVertexAiServiceTests.java
beltrangs May 16, 2025
5a2cfe5
Merge pull request #2 from beltrangslilly/google-chat-completion-tests
leo-hoet May 16, 2025
0cf1f3f
Merge branch 'vertexai-chatcompletion' of github.com:lhoet-google/ela…
lhoet-google May 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@
<sha256 value="26752413f76b8391dacefff40db867c1d33d0bf63d32954de3e9bb74cdcb8568" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.google.api" name="gax" version="2.64.2">
<artifact name="gax-2.64.2.jar">
<sha256 value="e3d4bdffef341e2a8ec0bf045dd520997cb6124c74388fe53101689c63e71a1f" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.google.api" name="gax-httpjson" version="0.105.1">
<artifact name="gax-httpjson-0.105.1.jar">
<sha256 value="4b7e1135eb4a97bce9d9d8c56128c5c30594dc2bebf26c9851ac582d2b43b2db" origin="Generated by Gradle"/>
Expand Down Expand Up @@ -561,6 +566,11 @@
<sha256 value="6a72ec2bb2350ca1970019e388d00808136e4da2e30296e9d8c346e3850b0eaa" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.google.cloud" name="google-cloud-aiplatform" version="3.61.0">
<artifact name="google-cloud-aiplatform-3.61.0.jar">
<sha256 value="2e3c5fa48c3b8b750b99b7f18d6d69c5c1d021d550f7325ffac968c0e5dd3339" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.google.cloud" name="google-cloud-core" version="2.53.1">
<artifact name="google-cloud-core-2.53.1.jar">
<sha256 value="58e008f119a7aaf68d2d13f530e997db6797b7aaa70e08c563421627bed382b0" origin="Generated by Gradle"/>
Expand Down
3 changes: 3 additions & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ dependencies {
/* SLF4J (via AWS SDKv2) */
api "org.slf4j:slf4j-api:${versions.slf4j}"
runtimeOnly "org.slf4j:slf4j-nop:${versions.slf4j}"
/* Google aiplatform SDK */
implementation 'com.google.cloud:google-cloud-aiplatform:3.61.0'
api "com.google.api:gax:2.64.2"
}

tasks.named("dependencyLicenses").configure {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.completion;

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings;

import java.net.URISyntaxException;
import java.util.Map;
import java.net.URI;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;

/**
 * Google Vertex AI model used for chat completion requests.
 */
public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel {

    /**
     * Creates a model from raw settings maps; each map is parsed into its typed settings object
     * before delegating to the typed constructor.
     */
    public GoogleVertexAiChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            GoogleVertexAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
            GoogleVertexAiChatCompletionTaskSettings.fromMap(taskSettings),
            GoogleVertexAiSecretSettings.fromMap(secrets)
        );
    }

    /**
     * Creates a model from already parsed settings and derives the request URI from the service settings.
     *
     * @throws RuntimeException wrapping the {@link URISyntaxException} if the URI cannot be built
     */
    GoogleVertexAiChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        GoogleVertexAiChatCompletionServiceSettings serviceSettings,
        GoogleVertexAiChatCompletionTaskSettings taskSettings,
        @Nullable GoogleVertexAiSecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
            new ModelSecrets(secrets),
            serviceSettings
        );
        this.uri = uriFor(serviceSettings);
    }

    /**
     * Creates a model directly from its configurations and secrets. This constructor only delegates
     * to the parent class; it does not build the request URI itself.
     */
    public GoogleVertexAiChatCompletionModel(
        ModelConfigurations configurations,
        ModelSecrets secrets,
        GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings
    ) {
        super(configurations, secrets, rateLimitServiceSettings);
    }

    /**
     * Returns a copy of {@code model} whose model id is taken from {@code request} when the request
     * specifies one, and from the original service settings otherwise. All other settings are carried over.
     */
    public static GoogleVertexAiChatCompletionModel of(GoogleVertexAiChatCompletionModel model, UnifiedCompletionRequest request) {
        var baseSettings = model.getServiceSettings();

        var overriddenSettings = new GoogleVertexAiChatCompletionServiceSettings(
            baseSettings.projectId(),
            baseSettings.location(),
            Objects.requireNonNullElse(request.model(), baseSettings.modelId()),
            baseSettings.rateLimitSettings()
        );

        return new GoogleVertexAiChatCompletionModel(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            overriddenSettings,
            model.getTaskSettings(),
            model.getSecretSettings()
        );
    }

    // Builds the request URI for the given settings, converting the checked syntax exception into an unchecked one.
    private static URI uriFor(GoogleVertexAiChatCompletionServiceSettings settings) {
        try {
            return buildUri(settings.location(), settings.projectId(), settings.modelId());
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
        return visitor.create(this, taskSettings);
    }

    @Override
    public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() {
        return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings();
    }

    @Override
    public GoogleVertexAiChatCompletionServiceSettings getServiceSettings() {
        return (GoogleVertexAiChatCompletionServiceSettings) super.getServiceSettings();
    }

    @Override
    public GoogleVertexAiChatCompletionTaskSettings getTaskSettings() {
        return (GoogleVertexAiChatCompletionTaskSettings) super.getTaskSettings();
    }

    @Override
    public GoogleVertexAiSecretSettings getSecretSettings() {
        return (GoogleVertexAiSecretSettings) super.getSecretSettings();
    }

    /**
     * Builds the HTTPS URI of the streaming content generation endpoint for a model.
     * <p>
     * The host is {@code location} followed by {@link GoogleVertexAiUtils#GOOGLE_VERTEX_AI_HOST_SUFFIX}.
     * The last path segment is {@code model}, a colon, and {@link GoogleVertexAiUtils#STREAM_GENERATE_CONTENT}.
     *
     * @param location  location used as the host prefix
     * @param projectId project id placed after the projects path segment
     * @param model     model id used in the final path segment
     * @return the assembled URI
     * @throws URISyntaxException if the components do not form a valid URI
     */
    public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
        String host = format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX);
        String modelAction = format("%s:%s", model, GoogleVertexAiUtils.STREAM_GENERATE_CONTENT);

        URIBuilder uriBuilder = new URIBuilder();
        uriBuilder.setScheme("https");
        uriBuilder.setHost(host);
        // NOTE(review): the locations path segment is the GLOBAL constant, not the location argument
        // that prefixes the host - confirm this combination is intended.
        uriBuilder.setPathSegments(
            GoogleVertexAiUtils.V1,
            GoogleVertexAiUtils.PROJECTS,
            projectId,
            GoogleVertexAiUtils.LOCATIONS,
            GoogleVertexAiUtils.GLOBAL,
            GoogleVertexAiUtils.PUBLISHERS,
            GoogleVertexAiUtils.PUBLISHER_GOOGLE,
            GoogleVertexAiUtils.MODELS,
            modelAction
        );
        return uriBuilder.build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.request;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * HTTP request for a unified chat completion call against Google Vertex AI.
 */
public class GoogleVertexAiUnifiedChatCompletionRequest implements GoogleVertexAiRequest {

    private final GoogleVertexAiChatCompletionModel model;
    private final UnifiedChatInput unifiedChatInput;

    /**
     * @param unifiedChatInput the chat input to serialize into the request body; must not be null
     * @param model            the model supplying the URI, secrets and inference entity id; must not be null
     */
    public GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) {
        this.model = Objects.requireNonNull(model);
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
    }

    /**
     * Builds a POST to the model's URI carrying the serialized request entity as UTF-8 JSON,
     * then applies authentication.
     */
    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(model.uri());

        var entity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model);
        String json = Strings.toString(entity);
        byte[] payload = json.getBytes(StandardCharsets.UTF_8);

        post.setEntity(new ByteArrayEntity(payload));
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

        decorateWithAuth(post);
        return new HttpRequest(post, getInferenceEntityId());
    }

    /**
     * Adds the bearer token derived from the model's secret settings to the given request.
     */
    public void decorateWithAuth(HttpPost httpPost) {
        GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings());
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    /**
     * Chat completion requests are never truncated, so the request itself is returned.
     */
    @Override
    public Request truncate() {
        return this;
    }

    /**
     * Chat completion requests are never truncated, so there is no truncation info to report.
     */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai.request;

import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;

public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXContentObject {
// Field names matching the Google Vertex AI API structure
private static final String CONTENTS = "contents";
private static final String ROLE = "role";
private static final String PARTS = "parts";
private static final String TEXT = "text";
private static final String GENERATION_CONFIG = "generationConfig";
private static final String TEMPERATURE = "temperature";
private static final String MAX_OUTPUT_TOKENS = "maxOutputTokens";
private static final String TOP_P = "topP";
// TODO: Add other generationConfig fields if needed (e.g., stopSequences, topK)

private final UnifiedChatInput unifiedChatInput;
private final GoogleVertexAiChatCompletionModel model;

private static final String USER_ROLE = "user";
private static final String MODEL_ROLE = "model";
private static final String STOP_SEQUENCES = "stopSequences";

public GoogleVertexAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model); // Keep the model reference
}

private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) throws IOException {
var messageRoleLowered = messageRole.toLowerCase();

if (messageRoleLowered.equals(USER_ROLE) || messageRoleLowered.equals(MODEL_ROLE)) {
return messageRoleLowered;
}

// TODO: Here is OK to throw an IOException?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better as an ElasticsearchStatusException with RestStatus.BAD_REQUEST since it is an unsupported configuration that the user has to take action on. Preferably, this is validated within GoogleVertexAiService but I'm okay with it being this late in the call chain as well

throw new IOException(
format(
"Role %s not supported by Google VertexAI ChatCompletion. Supported roles: '%s', '%s'",
messageRole,
USER_ROLE,
MODEL_ROLE
)
);

}

private void buildContents(XContentBuilder builder) throws IOException {
var messages = unifiedChatInput.getRequest().messages();

builder.startArray(CONTENTS);
for (UnifiedCompletionRequest.Message message : messages) {
builder.startObject();
builder.field(ROLE, messageRoleToGoogleVertexAiSupportedRole(message.role()));
builder.startArray(PARTS);
builder.startObject();
builder.field(TEXT, message.content().toString());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

builder.endObject();
builder.endArray();
builder.endObject();
}
builder.endArray();
}

private void buildGenerationConfig(XContentBuilder builder) throws IOException {
var request = unifiedChatInput.getRequest();

boolean hasAnyConfig = request.stop() != null
|| request.temperature() != null
|| request.maxCompletionTokens() != null
|| request.topP() != null;

if (hasAnyConfig == false) {
return;
}

builder.startObject(GENERATION_CONFIG);

if (request.stop() != null) {
builder.stringListField(STOP_SEQUENCES, request.stop());
}
if (request.temperature() != null) {
builder.field(TEMPERATURE, request.temperature());
}
if (request.maxCompletionTokens() != null) {
builder.field(MAX_OUTPUT_TOKENS, request.maxCompletionTokens());
}
if (request.topP() != null) {
builder.field(TOP_P, request.topP());
}

builder.endObject();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

buildContents(builder);
buildGenerationConfig(builder);

builder.endObject();
return builder;
}
}
Loading