Skip to content

Commit d0d7b7c

Browse files
committed
Add new transport version
1 parent bd9aa11 commit d0d7b7c

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ static TransportVersion def(int id) {
172172
public static final TransportVersion INTRODUCE_FAILURES_LIFECYCLE_BACKPORT_8_19 = def(8_841_0_25);
173173
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
174174
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
175+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_28);
175176
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
176177
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
177178
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -248,6 +249,7 @@ static TransportVersion def(int id) {
248249
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION = def(9_071_0_00);
249250
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
250251
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
252+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_074_0_00);
251253

252254
/*
253255
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
2424

2525
record SageMakerOpenAiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
26-
static final String NAME = "sagemaker_openai_text_embeddings_task_settings"; // also used for completion and chat completion
26+
static final String NAME = "sagemaker_openai_task_settings";
2727
private static final String USER_FIELD = "user";
2828

2929
SageMakerOpenAiTaskSettings(StreamInput in) throws IOException {
@@ -37,7 +37,7 @@ public String getWriteableName() {
3737

3838
@Override
3939
public TransportVersion getMinimalSupportedVersion() {
40-
return TransportVersions.ML_INFERENCE_SAGEMAKER;
40+
return TransportVersions.ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION;
4141
}
4242

4343
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
1212

1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.xcontent.XContentHelper;
1415
import org.elasticsearch.inference.InputType;
1516
import org.elasticsearch.inference.TaskType;
1617
import org.elasticsearch.inference.UnifiedCompletionRequest;
@@ -172,8 +173,8 @@ public void testResponse() throws Exception {
172173
}
173174

174175
public void testStreamResponse() throws Exception {
175-
var responseJson = """
176-
data: {
176+
var responseJson = dataPayload("""
177+
{
177178
"id":"12345",
178179
"object":"chat.completion.chunk",
179180
"created":123456789,
@@ -190,14 +191,18 @@ public void testStreamResponse() throws Exception {
190191
}
191192
]
192193
}
193-
""".replaceAll("\\s+", "").replaceAll("\\n+", "") + "\n\n";
194+
""");
194195

195-
var streamingResults = payload.streamResponseBody(mockModel(), SdkBytes.fromUtf8String(responseJson));
196+
var streamingResults = payload.streamResponseBody(mockModel(), responseJson);
196197

197198
assertThat(streamingResults.results().size(), is(1));
198199
assertThat(streamingResults.results().iterator().next().delta(), is("test"));
199200
}
200201

202+
private SdkBytes dataPayload(String json) throws IOException {
203+
return SdkBytes.fromUtf8String("data: " + XContentHelper.stripWhitespace(json) + "\n\n");
204+
}
205+
201206
private SageMakerModel mockModel() {
202207
SageMakerModel model = mock();
203208
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
@@ -260,12 +265,9 @@ public void testChatCompletionResponse() throws Exception {
260265
"total_tokens": 15
261266
}
262267
}
263-
""".replaceAll("\\s+", "").replaceAll("\\n+", "");
268+
""";
264269

265-
var chatCompletionResponse = payload.chatCompletionResponseBody(
266-
mockModel(),
267-
SdkBytes.fromUtf8String("data:" + responseJson + "\n\n")
268-
);
270+
var chatCompletionResponse = payload.chatCompletionResponseBody(mockModel(), dataPayload(responseJson));
269271

270272
XContentBuilder builder = JsonXContent.contentBuilder();
271273
chatCompletionResponse.toXContentChunked(null).forEachRemaining(xContent -> {
@@ -276,6 +278,6 @@ public void testChatCompletionResponse() throws Exception {
276278
}
277279
});
278280

279-
assertEquals(responseJson, Strings.toString(builder).trim());
281+
assertEquals(XContentHelper.stripWhitespace(responseJson), Strings.toString(builder).trim());
280282
}
281283
}

0 commit comments

Comments
 (0)