Skip to content

Commit 03fada0

Browse files
Separate unified and legacy code paths
1 parent 2660ecb commit 03fada0

23 files changed, with 740 additions and 255 deletions.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

Lines changed: 1 addition & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
import java.util.Iterator;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.Objects;
2423
import java.util.concurrent.Flow;
2524

2625
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
@@ -78,102 +77,16 @@ public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params
7877
}
7978
}
8079

81-
public record Result(String delta, String refusal, List<ToolCall> toolCalls) implements ChunkedToXContent {
82-
80+
public record Result(String delta) implements ChunkedToXContent {
8381
private static final String RESULT = "delta";
84-
private static final String REFUSAL = "refusal";
85-
private static final String TOOL_CALLS = "tool_calls";
86-
87-
public Result(String delta) {
88-
this(delta, "", List.of());
89-
}
9082

9183
@Override
9284
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
9385
return Iterators.concat(
9486
ChunkedToXContentHelper.startObject(),
9587
ChunkedToXContentHelper.field(RESULT, delta),
96-
ChunkedToXContentHelper.field(REFUSAL, refusal),
97-
ChunkedToXContentHelper.startArray(TOOL_CALLS),
98-
Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)),
99-
ChunkedToXContentHelper.endArray(),
10088
ChunkedToXContentHelper.endObject()
10189
);
10290
}
10391
}
104-
105-
public static class ToolCall implements ChunkedToXContent {
106-
private final int index;
107-
private final String id;
108-
private final String functionName;
109-
private final String functionArguments;
110-
111-
public ToolCall(int index, String id, String functionName, String functionArguments) {
112-
this.index = index;
113-
this.id = id;
114-
this.functionName = functionName;
115-
this.functionArguments = functionArguments;
116-
}
117-
118-
public int getIndex() {
119-
return index;
120-
}
121-
122-
public String getId() {
123-
return id;
124-
}
125-
126-
public String getFunctionName() {
127-
return functionName;
128-
}
129-
130-
public String getFunctionArguments() {
131-
return functionArguments;
132-
}
133-
134-
@Override
135-
public boolean equals(Object o) {
136-
if (this == o) return true;
137-
if (o == null || getClass() != o.getClass()) return false;
138-
ToolCall toolCall = (ToolCall) o;
139-
return index == toolCall.index
140-
&& Objects.equals(id, toolCall.id)
141-
&& Objects.equals(functionName, toolCall.functionName)
142-
&& Objects.equals(functionArguments, toolCall.functionArguments);
143-
}
144-
145-
@Override
146-
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
147-
return Iterators.concat(
148-
ChunkedToXContentHelper.startObject(),
149-
ChunkedToXContentHelper.field("index", index),
150-
ChunkedToXContentHelper.field("id", id),
151-
ChunkedToXContentHelper.field("functionName", functionName),
152-
ChunkedToXContentHelper.field("functionArguments", functionArguments),
153-
ChunkedToXContentHelper.endObject()
154-
);
155-
}
156-
157-
@Override
158-
public int hashCode() {
159-
return Objects.hash(index, id, functionName, functionArguments);
160-
}
161-
162-
@Override
163-
public String toString() {
164-
return "ToolCall{"
165-
+ "index="
166-
+ index
167-
+ ", id='"
168-
+ id
169-
+ '\''
170-
+ ", functionName='"
171-
+ functionName
172-
+ '\''
173-
+ ", functionArguments='"
174-
+ functionArguments
175-
+ '\''
176-
+ '}';
177-
}
178-
}
17992
}
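x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

Lines changed: 181 additions & 0 deletions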
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.results;
9+
10+
import org.elasticsearch.common.collect.Iterators;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
13+
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
14+
import org.elasticsearch.inference.InferenceResults;
15+
import org.elasticsearch.inference.InferenceServiceResults;
16+
import org.elasticsearch.xcontent.ToXContent;
17+
18+
import java.io.IOException;
19+
import java.util.Deque;
20+
import java.util.Iterator;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Objects;
24+
import java.util.concurrent.Flow;
25+
26+
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
27+
28+
/**
29+
* Chat Completion results that only contain a Flow.Publisher.
30+
*/
31+
public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher)
32+
implements
33+
InferenceServiceResults {
34+
35+
@Override
36+
public boolean isStreaming() {
37+
return true;
38+
}
39+
40+
@Override
41+
public List<? extends InferenceResults> transformToCoordinationFormat() {
42+
throw new UnsupportedOperationException("Not implemented");
43+
}
44+
45+
@Override
46+
public List<? extends InferenceResults> transformToLegacyFormat() {
47+
throw new UnsupportedOperationException("Not implemented");
48+
}
49+
50+
@Override
51+
public Map<String, Object> asMap() {
52+
throw new UnsupportedOperationException("Not implemented");
53+
}
54+
55+
@Override
56+
public String getWriteableName() {
57+
throw new UnsupportedOperationException("Not implemented");
58+
}
59+
60+
@Override
61+
public void writeTo(StreamOutput out) throws IOException {
62+
throw new UnsupportedOperationException("Not implemented");
63+
}
64+
65+
@Override
66+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
67+
throw new UnsupportedOperationException("Not implemented");
68+
}
69+
70+
public record Results(Deque<Result> results) implements ChunkedToXContent {
71+
@Override
72+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
73+
return Iterators.concat(
74+
ChunkedToXContentHelper.startObject(),
75+
ChunkedToXContentHelper.startArray(COMPLETION),
76+
Iterators.flatMap(results.iterator(), d -> d.toXContentChunked(params)),
77+
ChunkedToXContentHelper.endArray(),
78+
ChunkedToXContentHelper.endObject()
79+
);
80+
}
81+
}
82+
83+
public record Result(String delta, String refusal, List<ToolCall> toolCalls) implements ChunkedToXContent {
84+
85+
private static final String RESULT = "delta";
86+
private static final String REFUSAL = "refusal";
87+
private static final String TOOL_CALLS = "tool_calls";
88+
89+
public Result(String delta) {
90+
this(delta, "", List.of());
91+
}
92+
93+
@Override
94+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
95+
return Iterators.concat(
96+
ChunkedToXContentHelper.startObject(),
97+
ChunkedToXContentHelper.field(RESULT, delta),
98+
ChunkedToXContentHelper.field(REFUSAL, refusal),
99+
ChunkedToXContentHelper.startArray(TOOL_CALLS),
100+
Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)),
101+
ChunkedToXContentHelper.endArray(),
102+
ChunkedToXContentHelper.endObject()
103+
);
104+
}
105+
}
106+
107+
public static class ToolCall implements ChunkedToXContent {
108+
private final int index;
109+
private final String id;
110+
private final String functionName;
111+
private final String functionArguments;
112+
113+
public ToolCall(int index, String id, String functionName, String functionArguments) {
114+
this.index = index;
115+
this.id = id;
116+
this.functionName = functionName;
117+
this.functionArguments = functionArguments;
118+
}
119+
120+
public int getIndex() {
121+
return index;
122+
}
123+
124+
public String getId() {
125+
return id;
126+
}
127+
128+
public String getFunctionName() {
129+
return functionName;
130+
}
131+
132+
public String getFunctionArguments() {
133+
return functionArguments;
134+
}
135+
136+
@Override
137+
public boolean equals(Object o) {
138+
if (this == o) return true;
139+
if (o == null || getClass() != o.getClass()) return false;
140+
ToolCall toolCall = (ToolCall) o;
141+
return index == toolCall.index
142+
&& Objects.equals(id, toolCall.id)
143+
&& Objects.equals(functionName, toolCall.functionName)
144+
&& Objects.equals(functionArguments, toolCall.functionArguments);
145+
}
146+
147+
@Override
148+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
149+
return Iterators.concat(
150+
ChunkedToXContentHelper.startObject(),
151+
ChunkedToXContentHelper.field("index", index),
152+
ChunkedToXContentHelper.field("id", id),
153+
ChunkedToXContentHelper.field("functionName", functionName),
154+
ChunkedToXContentHelper.field("functionArguments", functionArguments),
155+
ChunkedToXContentHelper.endObject()
156+
);
157+
}
158+
159+
@Override
160+
public int hashCode() {
161+
return Objects.hash(index, id, functionName, functionArguments);
162+
}
163+
164+
@Override
165+
public String toString() {
166+
return "ToolCall{"
167+
+ "index="
168+
+ index
169+
+ ", id='"
170+
+ id
171+
+ '\''
172+
+ ", functionName='"
173+
+ functionName
174+
+ '\''
175+
+ ", functionArguments='"
176+
+ functionArguments
177+
+ '\''
178+
+ '}';
179+
}
180+
}
181+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ public void execute(
6969
Supplier<Boolean> hasRequestCompletedFunction,
7070
ActionListener<InferenceServiceResults> listener
7171
) {
72-
List<String> input = DocumentsOnlyInput.of(inferenceInputs).getInputs();
72+
List<String> input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs();
7373
AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model);
7474
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
7575
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,10 @@ public void execute(
4444
Supplier<Boolean> hasRequestCompletedFunction,
4545
ActionListener<InferenceServiceResults> listener
4646
) {
47-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
48-
var docsInput = docsOnly.getInputs();
49-
var stream = docsOnly.stream();
50-
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput);
47+
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
48+
var inputs = chatCompletionInput.getInputs();
49+
var stream = chatCompletionInput.stream();
50+
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs);
5151
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream);
5252
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();
5353

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,10 +46,10 @@ public void execute(
4646
Supplier<Boolean> hasRequestCompletedFunction,
4747
ActionListener<InferenceServiceResults> listener
4848
) {
49-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
50-
var docsInput = docsOnly.getInputs();
51-
var stream = docsOnly.stream();
52-
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream);
49+
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
50+
var inputs = chatCompletionInput.getInputs();
51+
var stream = chatCompletionInput.stream();
52+
AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(inputs, model, stream);
5353

5454
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5555
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,10 +41,10 @@ public void execute(
4141
Supplier<Boolean> hasRequestCompletedFunction,
4242
ActionListener<InferenceServiceResults> listener
4343
) {
44-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
45-
var docsInput = docsOnly.getInputs();
46-
var stream = docsOnly.stream();
47-
AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream);
44+
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
45+
var inputs = chatCompletionInput.getInputs();
46+
var stream = chatCompletionInput.stream();
47+
AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, inputs, stream);
4848

4949
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5050
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,10 +46,10 @@ public void execute(
4646
Supplier<Boolean> hasRequestCompletedFunction,
4747
ActionListener<InferenceServiceResults> listener
4848
) {
49-
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
50-
var docsInput = docsOnly.getInputs();
51-
var stream = docsOnly.stream();
52-
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream);
49+
var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
50+
var inputs = chatCompletionInput.getInputs();
51+
var stream = chatCompletionInput.stream();
52+
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(inputs, model, stream);
5353
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
5454
}
5555

0 commit comments

Comments
 (0)