|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.inference.external.request.elastic; |
9 | 9 |
|
10 | | -import org.elasticsearch.inference.UnifiedCompletionRequest; |
11 | 10 | import org.elasticsearch.xcontent.ToXContentObject; |
12 | 11 | import org.elasticsearch.xcontent.XContentBuilder; |
13 | 12 | import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; |
| 13 | +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; |
14 | 14 |
|
15 | 15 | import java.io.IOException; |
16 | 16 | import java.util.Objects; |
17 | 17 |
|
18 | 18 | public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject { |
19 | | - |
20 | | - public static final String NAME_FIELD = "name"; |
21 | | - public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; |
22 | | - public static final String TOOL_CALLS_FIELD = "tool_calls"; |
23 | | - public static final String ID_FIELD = "id"; |
24 | | - public static final String FUNCTION_FIELD = "function"; |
25 | | - public static final String ARGUMENTS_FIELD = "arguments"; |
26 | | - public static final String DESCRIPTION_FIELD = "description"; |
27 | | - public static final String PARAMETERS_FIELD = "parameters"; |
28 | | - public static final String STRICT_FIELD = "strict"; |
29 | | - public static final String TOP_P_FIELD = "top_p"; |
30 | | - public static final String USER_FIELD = "user"; |
31 | | - public static final String STREAM_FIELD = "stream"; |
32 | | - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; |
| 19 | + // TODO remove this if EIS doesn't use it |
33 | 20 | private static final String MODEL_FIELD = "model"; |
34 | | - public static final String MESSAGES_FIELD = "messages"; |
35 | | - private static final String ROLE_FIELD = "role"; |
36 | | - private static final String CONTENT_FIELD = "content"; |
37 | | - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; |
38 | | - private static final String STOP_FIELD = "stop"; |
39 | | - private static final String TEMPERATURE_FIELD = "temperature"; |
40 | | - private static final String TOOL_CHOICE_FIELD = "tool_choice"; |
41 | | - private static final String TOOL_FIELD = "tools"; |
42 | | - private static final String TEXT_FIELD = "text"; |
43 | | - private static final String TYPE_FIELD = "type"; |
44 | | - private static final String STREAM_OPTIONS_FIELD = "stream_options"; |
45 | | - private static final String INCLUDE_USAGE_FIELD = "include_usage"; |
46 | | - |
47 | | - private final UnifiedCompletionRequest unifiedRequest; |
48 | | - private final boolean stream; |
49 | 21 |
|
50 | | - public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { |
51 | | - Objects.requireNonNull(unifiedChatInput); |
| 22 | + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; |
| 23 | + private final String modelId; |
52 | 24 |
|
53 | | - this.unifiedRequest = unifiedChatInput.getRequest(); |
54 | | - this.stream = unifiedChatInput.stream(); |
| 25 | + public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { |
| 26 | + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); |
| 27 | + this.modelId = Objects.requireNonNull(modelId); |
55 | 28 | } |
56 | 29 |
|
57 | 30 | @Override |
58 | 31 | public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { |
59 | 32 | builder.startObject(); |
60 | | - builder.startArray(MESSAGES_FIELD); |
61 | | - { |
62 | | - for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { |
63 | | - builder.startObject(); |
64 | | - { |
65 | | - switch (message.content()) { |
66 | | - case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); |
67 | | - case UnifiedCompletionRequest.ContentObjects contentObjects -> { |
68 | | - builder.startArray(CONTENT_FIELD); |
69 | | - for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { |
70 | | - builder.startObject(); |
71 | | - builder.field(TEXT_FIELD, contentObject.text()); |
72 | | - builder.field(TYPE_FIELD, contentObject.type()); |
73 | | - builder.endObject(); |
74 | | - } |
75 | | - builder.endArray(); |
76 | | - } |
77 | | - } |
78 | | - |
79 | | - builder.field(ROLE_FIELD, message.role()); |
80 | | - if (message.name() != null) { |
81 | | - builder.field(NAME_FIELD, message.name()); |
82 | | - } |
83 | | - if (message.toolCallId() != null) { |
84 | | - builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); |
85 | | - } |
86 | | - if (message.toolCalls() != null) { |
87 | | - builder.startArray(TOOL_CALLS_FIELD); |
88 | | - for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { |
89 | | - builder.startObject(); |
90 | | - { |
91 | | - builder.field(ID_FIELD, toolCall.id()); |
92 | | - builder.startObject(FUNCTION_FIELD); |
93 | | - { |
94 | | - builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); |
95 | | - builder.field(NAME_FIELD, toolCall.function().name()); |
96 | | - } |
97 | | - builder.endObject(); |
98 | | - builder.field(TYPE_FIELD, toolCall.type()); |
99 | | - } |
100 | | - builder.endObject(); |
101 | | - } |
102 | | - builder.endArray(); |
103 | | - } |
104 | | - } |
105 | | - builder.endObject(); |
106 | | - } |
107 | | - } |
108 | | - builder.endArray(); |
109 | | - |
110 | | - if (unifiedRequest.maxCompletionTokens() != null) { |
111 | | - builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); |
112 | | - } |
113 | | - |
114 | | - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); |
115 | | - |
116 | | - if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { |
117 | | - builder.field(STOP_FIELD, unifiedRequest.stop()); |
118 | | - } |
119 | | - if (unifiedRequest.temperature() != null) { |
120 | | - builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); |
121 | | - } |
122 | | - if (unifiedRequest.toolChoice() != null) { |
123 | | - if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { |
124 | | - builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); |
125 | | - } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { |
126 | | - builder.startObject(TOOL_CHOICE_FIELD); |
127 | | - { |
128 | | - builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); |
129 | | - builder.startObject(FUNCTION_FIELD); |
130 | | - { |
131 | | - builder.field( |
132 | | - NAME_FIELD, |
133 | | - ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() |
134 | | - ); |
135 | | - } |
136 | | - builder.endObject(); |
137 | | - } |
138 | | - builder.endObject(); |
139 | | - } |
140 | | - } |
141 | | - if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { |
142 | | - builder.startArray(TOOL_FIELD); |
143 | | - for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { |
144 | | - builder.startObject(); |
145 | | - { |
146 | | - builder.field(TYPE_FIELD, t.type()); |
147 | | - builder.startObject(FUNCTION_FIELD); |
148 | | - { |
149 | | - builder.field(DESCRIPTION_FIELD, t.function().description()); |
150 | | - builder.field(NAME_FIELD, t.function().name()); |
151 | | - builder.field(PARAMETERS_FIELD, t.function().parameters()); |
152 | | - if (t.function().strict() != null) { |
153 | | - builder.field(STRICT_FIELD, t.function().strict()); |
154 | | - } |
155 | | - } |
156 | | - builder.endObject(); |
157 | | - } |
158 | | - builder.endObject(); |
159 | | - } |
160 | | - builder.endArray(); |
161 | | - } |
162 | | - if (unifiedRequest.topP() != null) { |
163 | | - builder.field(TOP_P_FIELD, unifiedRequest.topP()); |
164 | | - } |
165 | | - |
166 | | - builder.field(STREAM_FIELD, stream); |
167 | | - if (stream) { |
168 | | - builder.startObject(STREAM_OPTIONS_FIELD); |
169 | | - builder.field(INCLUDE_USAGE_FIELD, true); |
170 | | - builder.endObject(); |
171 | | - } |
| 33 | + unifiedRequestEntity.toXContent(builder, params); |
| 34 | + // TODO remove this if EIS doesn't use it |
| 35 | + builder.field(MODEL_FIELD, modelId); |
172 | 36 | builder.endObject(); |
173 | 37 |
|
174 | 38 | return builder; |
|
0 commit comments