Skip to content

Commit 933087d

Browse files
author
Max Hniebergall
committed
Merge branch 'ml-inference-unified-api-elastic' of github.com:elastic/elasticsearch into ml-inference-unified-api-elastic
2 parents 3dfb8f5 + 0166d98 commit 933087d

File tree

10 files changed

+291
-26
lines changed

10 files changed

+291
-26
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ void unifiedCompletionInfer(
127127
);
128128

129129
/**
130-
* Chunk long text according to {@code chunkingOptions} or the
131-
* model defaults if {@code chunkingOptions} contains unset
132-
* values.
130+
* Chunk long text.
133131
*
134132
* @param model The model
135133
* @param query Inference query, mainly for re-ranking

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ public static TaskType fromString(String name) {
3838
}
3939

4040
public static TaskType fromStringOrStatusException(String name) {
41+
if (name == null) {
42+
throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST);
43+
}
44+
4145
try {
4246
TaskType taskType = TaskType.fromString(name);
4347
return Objects.requireNonNull(taskType);

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C
6363
PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages"));
6464
PARSER.declareString(optionalConstructorArg(), new ParseField("model"));
6565
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens"));
66-
// PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"),
67-
// ObjectParser.ValueType.VALUE_ARRAY);
6866
PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop"));
6967
PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature"));
7068
PARSER.declareField(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
4343
private static final String TOOL_FIELD = "tools";
4444
private static final String TEXT_FIELD = "text";
4545
private static final String TYPE_FIELD = "type";
46+
private static final String STREAM_OPTIONS_FIELD = "stream_options";
47+
private static final String INCLUDE_USAGE_FIELD = "include_usage";
4648

4749
private final UnifiedCompletionRequest unifiedRequest;
4850
private final boolean stream;
@@ -169,6 +171,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
169171
}
170172

171173
builder.field(STREAM_FIELD, stream);
174+
if (stream) {
175+
builder.startObject(STREAM_OPTIONS_FIELD);
176+
builder.field(INCLUDE_USAGE_FIELD, true);
177+
builder.endObject();
178+
}
172179
builder.endObject();
173180

174181
return builder;
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.hamcrest.Matchers;
14+
15+
public class TaskTypeTests extends ESTestCase {
16+
17+
public void testFromStringOrStatusException() {
18+
var exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException(null));
19+
assertThat(exception.getMessage(), Matchers.is("Task type must not be null"));
20+
21+
exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException("blah"));
22+
assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]"));
23+
24+
assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY));
25+
}
26+
27+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase {
3131

3232
public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException {
33-
var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user");
33+
var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true);
3434
var httpRequest = request.createHttpRequest();
3535

3636
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -42,16 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio
4242
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
4343

4444
var requestMap = entityAsMap(httpPost.getEntity().getContent());
45-
assertThat(requestMap, aMapWithSize(5));
45+
assertRequestMapWithUser(requestMap, "user");
46+
}
47+
48+
private void assertRequestMapWithoutUser(Map<String, Object> requestMap) {
49+
assertRequestMapWithUser(requestMap, null);
50+
}
51+
52+
private void assertRequestMapWithUser(Map<String, Object> requestMap, @Nullable String user) {
53+
assertThat(requestMap, aMapWithSize(user != null ? 6 : 5));
4654
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
4755
assertThat(requestMap.get("model"), is("model"));
48-
assertThat(requestMap.get("user"), is("user"));
56+
if (user != null) {
57+
assertThat(requestMap.get("user"), is(user));
58+
}
4959
assertThat(requestMap.get("n"), is(1));
5060
assertTrue((Boolean) requestMap.get("stream"));
61+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
5162
}
5263

5364
public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException {
54-
var request = createRequest(null, "org", "secret", "abc", "model", "user");
65+
var request = createRequest(null, "org", "secret", "abc", "model", "user", true);
5566
var httpRequest = request.createHttpRequest();
5667

5768
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -63,16 +74,12 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce
6374
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
6475

6576
var requestMap = entityAsMap(httpPost.getEntity().getContent());
66-
assertThat(requestMap, aMapWithSize(5));
67-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
68-
assertThat(requestMap.get("model"), is("model"));
69-
assertThat(requestMap.get("user"), is("user"));
70-
assertThat(requestMap.get("n"), is(1));
71-
assertTrue((Boolean) requestMap.get("stream"));
77+
assertRequestMapWithUser(requestMap, "user");
78+
7279
}
7380

7481
public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException {
75-
var request = createRequest(null, null, "secret", "abc", "model", null);
82+
var request = createRequest(null, null, "secret", "abc", "model", null, true);
7683
var httpRequest = request.createHttpRequest();
7784

7885
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -84,14 +91,10 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws
8491
assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER));
8592

8693
var requestMap = entityAsMap(httpPost.getEntity().getContent());
87-
assertThat(requestMap, aMapWithSize(4));
88-
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
89-
assertThat(requestMap.get("model"), is("model"));
90-
assertThat(requestMap.get("n"), is(1));
91-
assertTrue((Boolean) requestMap.get("stream"));
94+
assertRequestMapWithoutUser(requestMap);
9295
}
9396

94-
public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException {
97+
public void testCreateRequest_WithStreaming() throws IOException {
9598
var request = createRequest(null, null, "secret", "abc", "model", null, true);
9699
var httpRequest = request.createHttpRequest();
97100

@@ -103,7 +106,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep
103106
}
104107

105108
public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException {
106-
var request = createRequest(null, null, "secret", "abcd", "model", null);
109+
var request = createRequest(null, null, "secret", "abcd", "model", null, true);
107110
var truncatedRequest = request.truncate();
108111
assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString()));
109112

@@ -112,17 +115,18 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException,
112115

113116
var httpPost = (HttpPost) httpRequest.httpRequestBase();
114117
var requestMap = entityAsMap(httpPost.getEntity().getContent());
115-
assertThat(requestMap, aMapWithSize(4));
118+
assertThat(requestMap, aMapWithSize(5));
116119

117120
// We do not truncate for OpenAi chat completions
118121
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
119122
assertThat(requestMap.get("model"), is("model"));
120123
assertThat(requestMap.get("n"), is(1));
121124
assertTrue((Boolean) requestMap.get("stream"));
125+
assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true)));
122126
}
123127

124128
public void testTruncationInfo_ReturnsNull() {
125-
var request = createRequest(null, null, "secret", "abcd", "model", null);
129+
var request = createRequest(null, null, "secret", "abcd", "model", null, true);
126130
assertNull(request.getTruncationInfo());
127131
}
128132

@@ -147,7 +151,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest(
147151
boolean stream
148152
) {
149153
var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user);
150-
return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", true), chatCompletionModel);
154+
return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
151155
}
152156

153157
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
package org.elasticsearch.xpack.inference.rest;
99

1010
import org.apache.lucene.util.SetOnce;
11+
import org.elasticsearch.ElasticsearchStatusException;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.common.bytes.BytesArray;
1314
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.inference.TaskType;
1416
import org.elasticsearch.rest.RestChannel;
1517
import org.elasticsearch.rest.RestRequest;
18+
import org.elasticsearch.rest.RestRequestTests;
1619
import org.elasticsearch.rest.action.RestChunkedToXContentListener;
1720
import org.elasticsearch.test.rest.FakeRestRequest;
1821
import org.elasticsearch.test.rest.RestActionTestCase;
@@ -26,6 +29,10 @@
2629
import java.util.Map;
2730

2831
import static org.elasticsearch.rest.RestRequest.Method.POST;
32+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams;
33+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
34+
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
35+
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID;
2936
import static org.hamcrest.CoreMatchers.is;
3037
import static org.hamcrest.Matchers.equalTo;
3138
import static org.hamcrest.Matchers.instanceOf;
@@ -56,6 +63,42 @@ private static String route(String param) {
5663
return "_route/" + param;
5764
}
5865

66+
public void testParseParams_ExtractsInferenceIdAndTaskType() {
67+
var params = parseParams(
68+
RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString()))
69+
);
70+
assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION)));
71+
}
72+
73+
public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() {
74+
var params = parseParams(
75+
RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString()))
76+
);
77+
assertThat(params, is(new BaseInferenceAction.Params("completion", TaskType.ANY)));
78+
}
79+
80+
public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() {
81+
var e = expectThrows(
82+
ElasticsearchStatusException.class,
83+
() -> parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id")))
84+
);
85+
assertThat(e.getMessage(), is("Task type must not be null"));
86+
}
87+
88+
public void testParseTimeout_ReturnsTimeout() {
89+
var timeout = parseTimeout(
90+
RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s"))
91+
);
92+
93+
assertThat(timeout, is(TimeValue.timeValueSeconds(4)));
94+
}
95+
96+
public void testParseTimeout_ReturnsDefaultTimeout() {
97+
var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of()));
98+
99+
assertThat(timeout, is(TimeValue.timeValueSeconds(30)));
100+
}
101+
59102
public void testUsesDefaultTimeout() {
60103
SetOnce<Boolean> executeCalled = new SetOnce<>();
61104
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rest;
9+
10+
import org.apache.lucene.util.SetOnce;
11+
import org.elasticsearch.common.bytes.BytesArray;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
14+
import org.elasticsearch.rest.AbstractRestChannel;
15+
import org.elasticsearch.rest.RestChannel;
16+
import org.elasticsearch.rest.RestRequest;
17+
import org.elasticsearch.rest.RestResponse;
18+
import org.elasticsearch.test.rest.FakeRestRequest;
19+
import org.elasticsearch.test.rest.RestActionTestCase;
20+
import org.elasticsearch.xcontent.XContentType;
21+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
22+
import org.junit.Before;
23+
24+
import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse;
25+
import static org.hamcrest.CoreMatchers.is;
26+
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.instanceOf;
28+
29+
public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase {
30+
31+
@Before
32+
public void setUpAction() {
33+
controller().registerHandler(new RestUnifiedCompletionInferenceAction());
34+
}
35+
36+
public void testStreamIsTrue() {
37+
SetOnce<Boolean> executeCalled = new SetOnce<>();
38+
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
39+
assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class));
40+
41+
var request = (UnifiedCompletionAction.Request) actionRequest;
42+
assertThat(request.isStreaming(), is(true));
43+
44+
executeCalled.set(true);
45+
return createResponse();
46+
}));
47+
48+
var requestBody = """
49+
{
50+
"messages": [
51+
{
52+
"content": "abc",
53+
"role": "user"
54+
}
55+
]
56+
}
57+
""";
58+
59+
RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
60+
.withPath("_inference/completion/test/_unified")
61+
.withContent(new BytesArray(requestBody), XContentType.JSON)
62+
.build();
63+
64+
final SetOnce<RestResponse> responseSetOnce = new SetOnce<>();
65+
dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) {
66+
@Override
67+
public void sendResponse(RestResponse response) {
68+
responseSetOnce.set(response);
69+
}
70+
});
71+
72+
// the response content will be null when there is no error
73+
assertNull(responseSetOnce.get().content());
74+
assertThat(executeCalled.get(), equalTo(true));
75+
}
76+
77+
private void dispatchRequest(final RestRequest request, final RestChannel channel) {
78+
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
79+
controller().dispatchRequest(request, channel, threadContext);
80+
}
81+
}

0 commit comments

Comments
 (0)