Skip to content

Commit d18da3c

Browse files
Jan-Kazlouski-elasticelasticsearchmachinejonathan-buttner
authored
Add Google Model Garden's Anthropic support to Inference Plugin (#134080)
* Add Google Model Garden Anthropic integration * Clean up AnthropicChatCompletionStreamingProcessor * Enhance GoogleVertexAiChatCompletionServiceSettings to support optional parameters based on transport version * Add extractOptionalUri method and corresponding tests for URI extraction * Add GoogleModelGardenProvider support to chat completion models and tests * Enhance AnthropicChatCompletionStreamingProcessor and related classes to support new content block types and improve parsing logic * Refactor AnthropicChatCompletionResponseHandler to use a custom error parser and add unit tests for response validation * Add unit tests for AnthropicChatCompletionStreamingProcessor to validate response parsing and error handling * Add unit tests for GoogleModelGardenAnthropicChatCompletionRequestEntity to validate serialization of user fields * Add support for Anthropic provider in Google Vertex AI chat completion model and update related tests * Add changelog * Refactor switch case in GoogleVertexAiActionCreator to handle null case * Validate service settings for Google Vertex AI model configuration * Enhance Anthropic model tests to validate URI handling and provider requirements * [CI] Auto commit changes from spotless * Refactor switch case in GoogleVertexAiService to handle null case * Simplify version check in GoogleVertexAiChatCompletionServiceSettings * Make GOOGLE provider default for GoogleModelGarden integration * Update anthropic_version to vertex-2024-10-22 in request entity and tests * Refactor Google Vertex AI request handling to improve provider management and error handling * Enhance validation for Google Model Garden settings to ensure required parameters are provided * Remove uri streamingUri and provider from rate limit grouping hash calculation * Refactor null and empty checks for projectId, location, and modelId in GoogleVertexAiChatCompletionServiceSettings * Refactor Google Model Garden integration to include task settings in request entity and enhance validation for max tokens * Revert "Update anthropic_version to vertex-2024-10-22 in request entity and tests" This reverts commit 63ea4b8. * Refactor Google Vertex AI settings to utilize GoogleVertexAiUtils for model garden support checks * [CI] Update transport version definitions * Update anthropic_version in tests and enhance validation logic for Google Vertex AI settings * Update versions * Enhance task settings validation in GoogleVertexAiChatCompletionModel * Address comments regarding anthropic version and configuration * [CI] Update transport version definitions * Add nullable annotation for maxTokens parameter in GoogleVertexAiChatCompletionTaskSettings * [CI] Update transport version definitions * Clarify URI handling logic in GoogleVertexAiChatCompletionModel comments * Make maxTokens nullable * [CI] Update transport version definitions * Fixed unit tests * [CI] Update transport version definitions * Fix validation logic for Google Model Garden and Vertex AI settings * [CI] Update transport version definitions * Add validation tests for Google Vertex AI and Model Garden settings * Refactor validation logic for Google Vertex AI and Model Garden settings * Add comment * Update Google Vertex AI Task Settings parsing logic and AnthropicChatCompletionStreamingProcessor readability --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 64b8574 commit d18da3c

File tree

30 files changed

+1934
-118
lines changed

30 files changed

+1934
-118
lines changed

docs/changelog/134080.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 134080
2+
summary: Added Google Model Garden Anthropic Completion and Chat Completion support to the Inference Plugin
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ public record UnifiedCompletionRequest(
5858
private static final String ROLE_FIELD = "role";
5959
private static final String CONTENT_FIELD = "content";
6060
private static final String STOP_FIELD = "stop";
61-
private static final String TEMPERATURE_FIELD = "temperature";
62-
private static final String TOOL_CHOICE_FIELD = "tool_choice";
63-
private static final String TOOL_FIELD = "tools";
61+
public static final String TEMPERATURE_FIELD = "temperature";
62+
public static final String TOOL_CHOICE_FIELD = "tool_choice";
63+
public static final String TOOL_FIELD = "tools";
6464
private static final String TEXT_FIELD = "text";
65-
private static final String TYPE_FIELD = "type";
65+
public static final String TYPE_FIELD = "type";
6666
private static final String MODEL_FIELD = "model";
6767
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
6868
private static final String MAX_TOKENS_FIELD = "max_tokens";
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9179000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
no_matching_project_exception,9178000
1+
ml_inference_google_model_garden_added,9179000

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,18 @@ public static URI extractUri(Map<String, Object> map, String fieldName, Validati
314314
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
315315
}
316316

317+
/**
318+
* Extracts an optional URI from the map. If the field is not present, null is returned. If the field is present but invalid,
319+
* @param map the map to extract the URI from
320+
* @param fieldName the field name to extract
321+
* @param validationException the validation exception to add errors to
322+
* @return the extracted URI or null if not present
323+
*/
324+
public static URI extractOptionalUri(Map<String, Object> map, String fieldName, ValidationException validationException) {
325+
String parsedUrl = extractOptionalString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
326+
return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException);
327+
}
328+
317329
public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
318330
try {
319331
return createOptionalUri(url);
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.anthropic;
9+
10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
12+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
13+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
14+
import org.elasticsearch.xpack.inference.external.http.retry.ChatCompletionErrorResponseHandler;
15+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
16+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract;
17+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
20+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
21+
import org.elasticsearch.xpack.inference.services.anthropic.response.AnthropicChatCompletionResponseEntity;
22+
23+
import java.util.concurrent.Flow;
24+
25+
/**
26+
* Handles streaming chat completion responses and error parsing for Anthropic inference endpoints.
27+
* Adapts the AnthropicResponseHandler to support chat completion schema.
28+
*/
29+
public class AnthropicChatCompletionResponseHandler extends AnthropicResponseHandler {
30+
private static final String ANTHROPIC_ERROR = "anthropic_error";
31+
private static final UnifiedChatCompletionErrorParserContract ANTHROPIC_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils
32+
.createErrorParserWithStringify(ANTHROPIC_ERROR);
33+
34+
private final ChatCompletionErrorResponseHandler chatCompletionErrorResponseHandler;
35+
36+
public AnthropicChatCompletionResponseHandler(String requestType) {
37+
this(requestType, AnthropicChatCompletionResponseEntity::fromResponse);
38+
}
39+
40+
private AnthropicChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
41+
super(requestType, parseFunction, true);
42+
this.chatCompletionErrorResponseHandler = new ChatCompletionErrorResponseHandler(ANTHROPIC_ERROR_PARSER);
43+
}
44+
45+
@Override
46+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
47+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
48+
var anthropicProcessor = new AnthropicChatCompletionStreamingProcessor(
49+
(m, e) -> chatCompletionErrorResponseHandler.buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e)
50+
);
51+
flow.subscribe(serverSentEventProcessor);
52+
serverSentEventProcessor.subscribe(anthropicProcessor);
53+
return new StreamingUnifiedChatCompletionResults(anthropicProcessor);
54+
}
55+
56+
@Override
57+
protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result) {
58+
return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result);
59+
}
60+
}

0 commit comments

Comments
 (0)