Skip to content

Commit e2ed5cc

Browse files
author
Max Hniebergall
committed
Merge branch 'ml-inference-unified-api-elastic' of github.com:elastic/elasticsearch into ml-inference-unified-api-elastic
2 parents 357277e + 10a5b12 commit e2ed5cc

File tree

9 files changed

+160
-50
lines changed

9 files changed

+160
-50
lines changed

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ public void writeTo(StreamOutput out) throws IOException {
203203
out.writeString(type);
204204
}
205205

206+
public String toString() {
207+
return text + ":" + type;
208+
}
209+
206210
}
207211

208212
public record ContentString(String content) implements Content, NamedWriteable {
@@ -230,6 +234,10 @@ public String getWriteableName() {
230234
public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
231235
builder.value(content);
232236
}
237+
238+
public String toString() {
239+
return content;
240+
}
233241
}
234242

235243
public record ToolCall(String id, FunctionField function, String type) implements Writeable {
@@ -390,7 +398,7 @@ public void writeTo(StreamOutput out) throws IOException {
390398
public record FunctionField(
391399
@Nullable String description,
392400
String name,
393-
@Nullable Map<String, Object> parameters, // TODO can we parse this as a string?
401+
@Nullable Map<String, Object> parameters,
394402
@Nullable Boolean strict
395403
) implements Writeable {
396404

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
2323

24-
public void testResultstoXContentChunked() throws IOException {
24+
public void testResults_toXContentChunked() throws IOException {
2525
String expected = """
2626
{
2727
"id": "chunk1",

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import org.elasticsearch.test.cluster.ElasticsearchCluster;
2222
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
2323
import org.elasticsearch.test.rest.ESRestTestCase;
24+
import org.elasticsearch.xcontent.XContentBuilder;
25+
import org.elasticsearch.xcontent.XContentFactory;
26+
import org.elasticsearch.xcontent.XContentType;
2427
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
2528
import org.junit.ClassRule;
2629

@@ -341,10 +344,21 @@ protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskTy
341344
return callAsync(endpoint, input);
342345
}
343346

347+
protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List<String> input)
348+
throws Exception {
349+
var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId);
350+
return callAsyncUnified(endpoint, input, "user");
351+
}
352+
344353
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
345-
var responseConsumer = new AsyncInferenceResponseConsumer();
346354
var request = new Request("POST", endpoint);
347355
request.setJsonEntity(jsonBody(input));
356+
357+
return execAsyncCall(request);
358+
}
359+
360+
private Deque<ServerSentEvent> execAsyncCall(Request request) throws Exception {
361+
var responseConsumer = new AsyncInferenceResponseConsumer();
348362
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
349363
var latch = new CountDownLatch(1);
350364
client().performRequestAsync(request, new ResponseListener() {
@@ -362,6 +376,22 @@ public void onFailure(Exception exception) {
362376
return responseConsumer.events();
363377
}
364378

379+
private Deque<ServerSentEvent> callAsyncUnified(String endpoint, List<String> input, String role) throws Exception {
380+
var request = new Request("POST", endpoint);
381+
382+
request.setJsonEntity(createUnifiedJsonBody(input, role));
383+
return execAsyncCall(request);
384+
}
385+
386+
private String createUnifiedJsonBody(List<String> input, String role) throws IOException {
387+
var messages = input.stream().map(i -> Map.of("content", i, "role", role)).toList();
388+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
389+
builder.startObject();
390+
builder.field("messages", messages);
391+
builder.endObject();
392+
return org.elasticsearch.common.Strings.toString(builder);
393+
}
394+
365395
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
366396
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
367397
return inferInternal(endpoint, input, Map.of());

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@
1111

1212
import org.apache.http.util.EntityUtils;
1313
import org.elasticsearch.client.ResponseException;
14+
import org.elasticsearch.common.Strings;
1415
import org.elasticsearch.common.settings.Settings;
1516
import org.elasticsearch.inference.TaskType;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
import org.elasticsearch.xcontent.XContentFactory;
19+
import org.elasticsearch.xcontent.XContentType;
1620
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
1721

1822
import java.io.IOException;
1923
import java.util.ArrayList;
2024
import java.util.Arrays;
25+
import java.util.Iterator;
2126
import java.util.List;
2227
import java.util.Map;
2328
import java.util.Objects;
@@ -481,6 +486,56 @@ public void testSupportedStream() throws Exception {
481486
}
482487
}
483488

489+
public void testUnifiedCompletionInference() throws Exception {
490+
String modelId = "streaming";
491+
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
492+
var singleModel = getModel(modelId);
493+
assertEquals(modelId, singleModel.get("inference_id"));
494+
assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
495+
496+
var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomUUID()).toList();
497+
try {
498+
var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input);
499+
var expectedResponses = expectedResultsIterator(input);
500+
assertThat(events.size(), equalTo((input.size() + 1) * 2));
501+
events.forEach(event -> {
502+
switch (event.name()) {
503+
case EVENT -> assertThat(event.value(), equalToIgnoringCase("message"));
504+
case DATA -> assertThat(event.value(), equalTo(expectedResponses.next()));
505+
}
506+
});
507+
} finally {
508+
deleteModel(modelId);
509+
}
510+
}
511+
512+
private static Iterator<String> expectedResultsIterator(List<String> input) {
513+
return Stream.concat(input.stream().map(String::toUpperCase).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]")).iterator();
514+
}
515+
516+
private static String expectedResult(String input) {
517+
try {
518+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
519+
builder.startObject();
520+
builder.field("id", "id");
521+
builder.startArray("choices");
522+
builder.startObject();
523+
builder.startObject("delta");
524+
builder.field("content", input);
525+
builder.endObject();
526+
builder.field("index", 0);
527+
builder.endObject();
528+
builder.endArray();
529+
builder.field("model", "gpt-4o-2024-08-06");
530+
builder.field("object", "chat.completion.chunk");
531+
builder.endObject();
532+
533+
return Strings.toString(builder);
534+
} catch (IOException e) {
535+
throw new RuntimeException(e);
536+
}
537+
}
538+
484539
public void testGetZeroModels() throws IOException {
485540
var models = getModels("_all", TaskType.RERANK);
486541
assertThat(models, empty());

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.xcontent.ToXContentObject;
3838
import org.elasticsearch.xcontent.XContentBuilder;
3939
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
40+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
4041

4142
import java.io.IOException;
4243
import java.util.EnumSet;
@@ -129,7 +130,15 @@ public void unifiedCompletionInfer(
129130
TimeValue timeout,
130131
ActionListener<InferenceServiceResults> listener
131132
) {
132-
listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); // TODO
133+
switch (model.getConfigurations().getTaskType()) {
134+
case COMPLETION -> listener.onResponse(makeUnifiedResults(request));
135+
default -> listener.onFailure(
136+
new ElasticsearchStatusException(
137+
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
138+
RestStatus.BAD_REQUEST
139+
)
140+
);
141+
}
133142
}
134143

135144
private StreamingChatCompletionResults makeResults(List<String> input) {
@@ -163,6 +172,59 @@ private ChunkedToXContent completionChunk(String delta) {
163172
);
164173
}
165174

175+
private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) {
176+
var responseIter = request.messages().stream().map(message -> message.content().toString().toUpperCase()).iterator();
177+
return new StreamingUnifiedChatCompletionResults(subscriber -> {
178+
subscriber.onSubscribe(new Flow.Subscription() {
179+
@Override
180+
public void request(long n) {
181+
if (responseIter.hasNext()) {
182+
subscriber.onNext(unifiedCompletionChunk(responseIter.next()));
183+
} else {
184+
subscriber.onComplete();
185+
}
186+
}
187+
188+
@Override
189+
public void cancel() {}
190+
});
191+
});
192+
}
193+
194+
/*
195+
The response format looks like this
196+
{
197+
"id": "chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF",
198+
"choices": [
199+
{
200+
"delta": {
201+
"content": " information"
202+
},
203+
"index": 0
204+
}
205+
],
206+
"model": "gpt-4o-2024-08-06",
207+
"object": "chat.completion.chunk"
208+
}
209+
*/
210+
private ChunkedToXContent unifiedCompletionChunk(String delta) {
211+
return params -> Iterators.concat(
212+
ChunkedToXContentHelper.startObject(),
213+
ChunkedToXContentHelper.field("id", "id"),
214+
ChunkedToXContentHelper.startArray("choices"),
215+
ChunkedToXContentHelper.startObject(),
216+
ChunkedToXContentHelper.startObject("delta"),
217+
ChunkedToXContentHelper.field("content", delta),
218+
ChunkedToXContentHelper.endObject(),
219+
ChunkedToXContentHelper.field("index", 0),
220+
ChunkedToXContentHelper.endObject(),
221+
ChunkedToXContentHelper.endArray(),
222+
ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"),
223+
ChunkedToXContentHelper.field("object", "chat.completion.chunk"),
224+
ChunkedToXContentHelper.endObject()
225+
);
226+
}
227+
166228
@Override
167229
public void chunkedInfer(
168230
Model model,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -92,28 +92,6 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
9292
return;
9393
}
9494

95-
// if (service.isEmpty()) {
96-
// var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
97-
// recordMetrics(unparsedModel, timer, e);
98-
// listener.onFailure(e);
99-
// return;
100-
// }
101-
102-
// if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
103-
// // not the wildcard task type and not the model task type
104-
// var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
105-
// recordMetrics(unparsedModel, timer, e);
106-
// listener.onFailure(e);
107-
// return;
108-
// }
109-
110-
// if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
111-
// var e = createInvalidTaskTypeException(request, unparsedModel);
112-
// recordMetrics(unparsedModel, timer, e);
113-
// listener.onFailure(e);
114-
// return;
115-
// }
116-
11795
var model = service.get()
11896
.parsePersistedConfigWithSecrets(
11997
unparsedModel.inferenceEntityId(),
@@ -195,27 +173,6 @@ private void inferOnService(Model model, Request request, InferenceService servi
195173
}
196174
}
197175

198-
// private static Runnable inferRunnable(
199-
// Model model,
200-
// T request,
201-
// InferenceService service,
202-
// ActionListener<InferenceServiceResults> listener
203-
// ) {
204-
// return request.isUnifiedCompletionMode()
205-
// // TODO add parameters
206-
// ? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener)
207-
// : () -> service.infer(
208-
// model,
209-
// request.getQuery(),
210-
// request.getInput(),
211-
// request.isStreaming(),
212-
// request.getTaskSettings(),
213-
// request.getInputType(),
214-
// request.getInferenceTimeout(),
215-
// listener
216-
// );
217-
// }
218-
219176
protected abstract void doInference(
220177
Model model,
221178
Request request,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
3737
public static final String FUNCTION_FIELD = "function";
3838
private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);
39-
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response";
4039

4140
private static final String CHOICES_FIELD = "choices";
4241
private static final String DELTA_FIELD = "delta";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> rec
358358
target.write(ServerSentEventSpec.EOL);
359359
target.write(ServerSentEventSpec.EOL);
360360
target.flush();
361-
362361
}
363362
final var result = new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream));
364363
target = null;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ public void testInfer_SendsRequest() throws IOException {
922922
}
923923

924924
public void testUnifiedCompletionInfer() throws Exception {
925-
// streaming response must be on a single line
925+
// The trailing backslashes join the text-block lines, because the streaming response must be on a single line
926926
String responseJson = """
927927
data: {\
928928
"id":"12345",\

0 commit comments

Comments
 (0)