77
88package  org .elasticsearch .xpack .inference .external .request .openai ;
99
10- import  org .elasticsearch .common .Strings ;
11- import  org .elasticsearch .core .Nullable ;
1210import  org .elasticsearch .xcontent .ToXContentObject ;
1311import  org .elasticsearch .xcontent .XContentBuilder ;
1412import  org .elasticsearch .xpack .core .inference .action .UnifiedCompletionRequest ;
1513import  org .elasticsearch .xpack .inference .external .http .sender .DocumentsOnlyInput ;
14+ import  org .elasticsearch .xpack .inference .external .request .UnifiedRequest ;
1615
1716import  java .io .IOException ;
1817import  java .util .List ;
19- import  java .util .Objects ;
2018
2119public  class  OpenAiUnifiedChatCompletionRequestEntity  implements  ToXContentObject  {
2220
23-     private  static  final  String  MESSAGES_FIELD  = "messages" ;
24-     private  static  final  String  MODEL_FIELD  = "model" ;
25- 
21+     public  static  final  String  NAME_FIELD  = "name" ;
22+     public  static  final  String  TOOL_CALL_ID_FIELD  = "tool_call_id" ;
23+     public  static  final  String  TOOL_CALLS_FIELD  = "tool_calls" ;
24+     public  static  final  String  ID_FIELD  = "id" ;
25+     public  static  final  String  FUNCTION_FIELD  = "function" ;
26+     public  static  final  String  ARGUMENTS_FIELD  = "arguments" ;
27+     public  static  final  String  DESCRIPTION_FIELD  = "description" ;
28+     public  static  final  String  PARAMETERS_FIELD  = "parameters" ;
29+     public  static  final  String  STRICT_FIELD  = "strict" ;
30+     public  static  final  String  TOP_P_FIELD  = "top_p" ;
31+     public  static  final  String  USER_FIELD  = "user" ;
32+     public  static  final  String  STREAM_FIELD  = "stream" ;
2633    private  static  final  String  NUMBER_OF_RETURNED_CHOICES_FIELD  = "n" ;
27- 
34+     private  static  final  String  MODEL_FIELD  = "model" ;
35+     public  static  final  String  MESSAGES_FIELD  = "messages" ;
2836    private  static  final  String  ROLE_FIELD  = "role" ;
29-     private  static  final  String  USER_FIELD  = "user" ;
3037    private  static  final  String  CONTENT_FIELD  = "content" ;
31-     private  static  final  String  STREAM_FIELD  = "stream" ;
3238    private  static  final  String  MAX_COMPLETION_TOKENS_FIELD  = "max_completion_tokens" ;
3339    private  static  final  String  STOP_FIELD  = "stop" ;
3440    private  static  final  String  TEMPERATURE_FIELD  = "temperature" ;
3541    private  static  final  String  TOOL_CHOICE_FIELD  = "tool_choice" ;
3642    private  static  final  String  TOOL_FIELD  = "tool" ;
37-     private  static  final  String  TOP_P_FIELD  = "top_p" ;
43+     private  static  final  String  TEXT_FIELD  = "text" ;
44+     private  static  final  String  TYPE_FIELD  = "type" ;
3845
39-     private  final  String   user ;
46+     private  final  UnifiedRequest   unifiedRequest ;
4047
41-     public  boolean   isStream ( ) {
42-         return   stream ;
48+     public  OpenAiUnifiedChatCompletionRequestEntity ( UnifiedRequest   unifiedRequest ) {
49+         this . unifiedRequest  =  unifiedRequest ;
4350    }
4451
45-     private  final  boolean  stream ;
46-     private  final  Long  maxCompletionTokens ;
47-     private  final  Integer  n ;
48-     private  final  UnifiedCompletionRequest .Stop  stop ;
49-     private  final  Float  temperature ;
50-     private  final  UnifiedCompletionRequest .ToolChoice  toolChoice ;
51-     private  final  List <UnifiedCompletionRequest .Tool > tool ;
52-     private  final  Float  topP ;
53-     private  final  List <UnifiedCompletionRequest .Message > messages ;
54-     private  final  String  model ;
55- 
5652    public  OpenAiUnifiedChatCompletionRequestEntity (DocumentsOnlyInput  input ) {
57-         this (convertDocumentsOnlyInputToMessages (input ), null , null , null , null , null , null , null , null , null );
53+         this (new   UnifiedRequest ( convertDocumentsOnlyInputToMessages (input ), null , null , null , null , null , null , null , null , null ,  true ) );
5854    }
5955
6056    private  static  List <UnifiedCompletionRequest .Message > convertDocumentsOnlyInputToMessages (DocumentsOnlyInput  input ) {
6157        return  input .getInputs ()
6258            .stream ()
63-             .map (doc  -> new  UnifiedCompletionRequest .Message (new  UnifiedCompletionRequest .ContentString (doc ), "user" , null , null , null ))
59+             .map (doc  -> new  UnifiedCompletionRequest .Message (new  UnifiedCompletionRequest .ContentString (doc ), USER_FIELD , null , null , null ))
6460            .toList ();
6561    }
6662
67-     public  OpenAiUnifiedChatCompletionRequestEntity (
68-         List <UnifiedCompletionRequest .Message > messages ,
69-         @ Nullable  String  model ,
70-         @ Nullable  Long  maxCompletionTokens ,
71-         @ Nullable  Integer  n ,
72-         @ Nullable  UnifiedCompletionRequest .Stop  stop ,
73-         @ Nullable  Float  temperature ,
74-         @ Nullable  UnifiedCompletionRequest .ToolChoice  toolChoice ,
75-         @ Nullable  List <UnifiedCompletionRequest .Tool > tool ,
76-         @ Nullable  Float  topP ,
77-         @ Nullable  String  user 
78-     ) {
79-         Objects .requireNonNull (messages );
80-         Objects .requireNonNull (model );
81- 
82-         this .user  = user ;
83-         this .stream  = true ; // always stream in unified API 
84-         this .maxCompletionTokens  = maxCompletionTokens ;
85-         this .n  = n ;
86-         this .stop  = stop ;
87-         this .temperature  = temperature ;
88-         this .toolChoice  = toolChoice ;
89-         this .tool  = tool ;
90-         this .topP  = topP ;
91-         this .messages  = messages ;
92-         this .model  = model ;
93- 
94-     }
95- 
9663    @ Override 
9764    public  XContentBuilder  toXContent (XContentBuilder  builder , Params  params ) throws  IOException  {
9865        builder .startObject ();
9966        builder .startArray (MESSAGES_FIELD );
10067        {
101-             for  (UnifiedCompletionRequest .Message  message  : messages ) {
68+             for  (UnifiedCompletionRequest .Message  message  : unifiedRequest . messages () ) {
10269                builder .startObject ();
10370                {
10471                    switch  (message .content ()) {
10572                        case  UnifiedCompletionRequest .ContentString  contentString  -> builder .field (CONTENT_FIELD , contentString .content ());
10673                        case  UnifiedCompletionRequest .ContentObjects  contentObjects  -> {
10774                            for  (UnifiedCompletionRequest .ContentObject  contentObject  : contentObjects .contentObjects ()) {
10875                                builder .startObject (CONTENT_FIELD );
109-                                 builder .field ("text" , contentObject .text ());
110-                                 builder .field ("type" , contentObject .type ());
76+                                 builder .field (TEXT_FIELD , contentObject .text ());
77+                                 builder .field (TYPE_FIELD , contentObject .type ());
11178                                builder .endObject ();
11279                            }
11380                        }
11481                    }
11582
11683                    builder .field (ROLE_FIELD , message .role ());
11784                    if  (message .name () != null ) {
118-                         builder .field ("name" , message .name ());
85+                         builder .field (NAME_FIELD , message .name ());
11986                    }
12087                    if  (message .toolCallId () != null ) {
121-                         builder .field ("tool_call_id" , message .toolCallId ());
88+                         builder .field (TOOL_CALL_ID_FIELD , message .toolCallId ());
12289                    }
12390                    if  (message .toolCalls () != null ) {
124-                         builder .startArray ("tool_calls" );
91+                         builder .startArray (TOOL_CALLS_FIELD );
12592                        for  (UnifiedCompletionRequest .ToolCall  toolCall  : message .toolCalls ()) {
12693                            builder .startObject ();
12794                            {
128-                                 builder .field ("id" , toolCall .id ());
129-                                 builder .startObject ("function" );
95+                                 builder .field (ID_FIELD , toolCall .id ());
96+                                 builder .startObject (FUNCTION_FIELD );
13097                                {
131-                                     builder .field ("arguments" , toolCall .function ().arguments ());
132-                                     builder .field ("name" , toolCall .function ().name ());
98+                                     builder .field (ARGUMENTS_FIELD , toolCall .function ().arguments ());
99+                                     builder .field (NAME_FIELD , toolCall .function ().name ());
133100                                }
134101                                builder .endObject ();
135-                                 builder .field ("type" , toolCall .type ());
102+                                 builder .field (TYPE_FIELD , toolCall .type ());
136103                            }
137104                            builder .endObject ();
138105                        }
@@ -144,65 +111,69 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
144111        }
145112        builder .endArray ();
146113
147-         if  (model  != null ) {
148-             builder .field (MODEL_FIELD , model );
114+         if  (unifiedRequest . model ()  != null ) {
115+             builder .field (MODEL_FIELD , unifiedRequest . model () );
149116        }
150-         if  (maxCompletionTokens  != null ) {
151-             builder .field (MAX_COMPLETION_TOKENS_FIELD , maxCompletionTokens );
117+         if  (unifiedRequest . maxCompletionTokens ()  != null ) {
118+             builder .field (MAX_COMPLETION_TOKENS_FIELD , unifiedRequest . maxCompletionTokens () );
152119        }
153-         if  (n  != null ) {
154-             builder .field (NUMBER_OF_RETURNED_CHOICES_FIELD , n );
120+         if  (unifiedRequest . n ()  != null ) {
121+             builder .field (NUMBER_OF_RETURNED_CHOICES_FIELD , unifiedRequest . n () );
155122        }
156-         if  (stop  != null ) {
157-             switch  (stop ) {
123+         if  (unifiedRequest . stop ()  != null ) {
124+             switch  (unifiedRequest . stop () ) {
158125                case  UnifiedCompletionRequest .StopString  stopString  -> builder .field (STOP_FIELD , stopString .value ());
159126                case  UnifiedCompletionRequest .StopValues  stopValues  -> builder .field (STOP_FIELD , stopValues .values ());
160127            }
161128        }
162-         if  (temperature  != null ) {
163-             builder .field (TEMPERATURE_FIELD , temperature );
129+         if  (unifiedRequest . temperature ()  != null ) {
130+             builder .field (TEMPERATURE_FIELD , unifiedRequest . temperature () );
164131        }
165-         if  (toolChoice  != null ) {
166-             if  (toolChoice  instanceof  UnifiedCompletionRequest .ToolChoiceString ) {
167-                 builder .field (TOOL_CHOICE_FIELD , ((UnifiedCompletionRequest .ToolChoiceString ) toolChoice ).value ());
168-             } else  if  (toolChoice  instanceof  UnifiedCompletionRequest .ToolChoiceObject ) {
132+         if  (unifiedRequest . toolChoice ()  != null ) {
133+             if  (unifiedRequest . toolChoice ()  instanceof  UnifiedCompletionRequest .ToolChoiceString ) {
134+                 builder .field (TOOL_CHOICE_FIELD , ((UnifiedCompletionRequest .ToolChoiceString ) unifiedRequest . toolChoice () ).value ());
135+             } else  if  (unifiedRequest . toolChoice ()  instanceof  UnifiedCompletionRequest .ToolChoiceObject ) {
169136                builder .startObject (TOOL_CHOICE_FIELD );
170137                {
171-                     builder .field ("type" , ((UnifiedCompletionRequest .ToolChoiceObject ) toolChoice ).type ());
172-                     builder .startObject ("function" );
138+                     builder .field (TYPE_FIELD , ((UnifiedCompletionRequest .ToolChoiceObject ) unifiedRequest . toolChoice () ).type ());
139+                     builder .startObject (FUNCTION_FIELD );
173140                    {
174-                         builder .field ("name" , ((UnifiedCompletionRequest .ToolChoiceObject ) toolChoice ).function ().name ());
141+                         builder .field (
142+                             NAME_FIELD ,
143+                             ((UnifiedCompletionRequest .ToolChoiceObject ) unifiedRequest .toolChoice ()).function ().name ()
144+                         );
175145                    }
176146                    builder .endObject ();
177147                }
178148                builder .endObject ();
179149            }
180150        }
181-         if  (tool  != null ) {
151+         if  (unifiedRequest . tool ()  != null ) {
182152            builder .startArray (TOOL_FIELD );
183-             for  (UnifiedCompletionRequest .Tool  t  : tool ) {
153+             for  (UnifiedCompletionRequest .Tool  t  : unifiedRequest . tool () ) {
184154                builder .startObject ();
185155                {
186-                     builder .field ("type" , t .type ());
187-                     builder .startObject ("function" );
156+                     builder .field (TYPE_FIELD , t .type ());
157+                     builder .startObject (FUNCTION_FIELD );
188158                    {
189-                         builder .field ("description" , t .function ().description ());
190-                         builder .field ("name" , t .function ().name ());
191-                         builder .field ("parameters" , t .function ().parameters ());
192-                         builder .field ("strict" , t .function ().strict ());
159+                         builder .field (DESCRIPTION_FIELD , t .function ().description ());
160+                         builder .field (NAME_FIELD , t .function ().name ());
161+                         builder .field (PARAMETERS_FIELD , t .function ().parameters ());
162+                         builder .field (STRICT_FIELD , t .function ().strict ());
193163                    }
194164                    builder .endObject ();
195165                }
196166                builder .endObject ();
197167            }
198168            builder .endArray ();
199169        }
200-         if  (topP  != null ) {
201-             builder .field (TOP_P_FIELD , topP );
170+         if  (unifiedRequest . topP ()  != null ) {
171+             builder .field (TOP_P_FIELD , unifiedRequest . topP () );
202172        }
203-         if  (Strings . isNullOrEmpty ( user ) == false ) {
204-             builder .field (USER_FIELD , user );
173+         if  (unifiedRequest . user () !=  null  &&  unifiedRequest . user (). isEmpty ( ) == false ) {
174+             builder .field (USER_FIELD , unifiedRequest . user () );
205175        }
176+         builder .field (STREAM_FIELD , unifiedRequest .stream ());
206177        builder .endObject ();
207178        return  builder ;
208179    }
0 commit comments