1212import com .sap .ai .sdk .foundationmodels .openai .OpenAiMessage ;
1313import com .sap .ai .sdk .foundationmodels .openai .OpenAiToolCall ;
1414import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionMessageToolCall ;
15- import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionResponseMessage ;
1615import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionTool ;
16+ import com .sap .ai .sdk .foundationmodels .openai .generated .model .CreateChatCompletionResponseChoicesInner ;
1717import com .sap .ai .sdk .foundationmodels .openai .generated .model .FunctionObject ;
1818import io .vavr .control .Option ;
19+ import java .math .BigDecimal ;
1920import java .util .ArrayList ;
2021import java .util .List ;
2122import java .util .Map ;
2728import org .springframework .ai .chat .messages .AssistantMessage .ToolCall ;
2829import org .springframework .ai .chat .messages .Message ;
2930import org .springframework .ai .chat .messages .ToolResponseMessage ;
31+ import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
3032import org .springframework .ai .chat .model .ChatModel ;
3133import org .springframework .ai .chat .model .ChatResponse ;
3234import org .springframework .ai .chat .model .Generation ;
35+ import org .springframework .ai .chat .prompt .ChatOptions ;
3336import org .springframework .ai .chat .prompt .Prompt ;
34- import org .springframework .ai .model .tool .DefaultToolCallingChatOptions ;
3537import org .springframework .ai .model .tool .DefaultToolCallingManager ;
38+ import org .springframework .ai .model .tool .ToolCallingChatOptions ;
3639import reactor .core .publisher .Flux ;
3740
3841/**
@@ -50,34 +53,40 @@ public class OpenAiChatModel implements ChatModel {
5053 @ Override
5154 @ Nonnull
5255 public ChatResponse call (@ Nonnull final Prompt prompt ) {
53- val openAiRequest = toOpenAiRequest ( prompt );
54- var request = new OpenAiChatCompletionRequest (openAiRequest );
56+ val options = prompt . getOptions ( );
57+ var request = new OpenAiChatCompletionRequest (extractMessages ( prompt ) );
5558
56- if ((prompt .getOptions () instanceof DefaultToolCallingChatOptions options )) {
57- request = request .withTools (extractTools (options ));
59+ if (options != null ) {
60+ request = extractOptions (request , options );
61+ }
62+ if ((options instanceof ToolCallingChatOptions toolOptions )) {
63+ request = request .withTools (extractTools (toolOptions ));
5864 }
5965
6066 val result = client .chatCompletion (request );
6167 val response = new ChatResponse (toGenerations (result ));
6268
63- if (prompt .getOptions () != null
64- && isInternalToolExecutionEnabled (prompt .getOptions ())
65- && response .hasToolCalls ()) {
69+ if (options != null && isInternalToolExecutionEnabled (options ) && response .hasToolCalls ()) {
6670 val toolExecutionResult = toolCallingManager .executeToolCalls (prompt , response );
6771 // Send the tool execution result back to the model.
68- return call (new Prompt (toolExecutionResult .conversationHistory (), prompt . getOptions () ));
72+ return call (new Prompt (toolExecutionResult .conversationHistory (), options ));
6973 }
7074 return response ;
7175 }
7276
7377 @ Override
7478 @ Nonnull
7579 public Flux <ChatResponse > stream (@ Nonnull final Prompt prompt ) {
76- val openAiRequest = toOpenAiRequest (prompt );
77- var request = new OpenAiChatCompletionRequest (openAiRequest );
78- if ((prompt .getOptions () instanceof DefaultToolCallingChatOptions options )) {
79- request = request .withTools (extractTools (options ));
80+ val options = prompt .getOptions ();
81+ var request = new OpenAiChatCompletionRequest (extractMessages (prompt ));
82+
83+ if (options != null ) {
84+ request = extractOptions (request , options );
85+ }
86+ if ((options instanceof ToolCallingChatOptions toolOptions )) {
87+ request = request .withTools (extractTools (toolOptions ));
8088 }
89+
8190 val stream = client .streamChatCompletionDeltas (request );
8291 final Flux <OpenAiChatCompletionDelta > flux =
8392 Flux .generate (
@@ -90,36 +99,16 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
9099 }
91100 return iterator ;
92101 });
93- return flux .map (OpenAiChatModel ::toChatResponse );
94- }
95-
96- private List <ChatCompletionTool > extractTools (final DefaultToolCallingChatOptions options ) {
97- val tools = new ArrayList <ChatCompletionTool >();
98- for (val toolCallback : options .getToolCallbacks ()) {
99- val toolDefinition = toolCallback .getToolDefinition ();
100- try {
101- final Map <String , Object > params =
102- new ObjectMapper ().readValue (toolDefinition .inputSchema (), new TypeReference <>() {});
103- val toolType = ChatCompletionTool .TypeEnum .FUNCTION ;
104- val toolFunction =
105- new FunctionObject ()
106- .name (toolDefinition .name ())
107- .description (toolDefinition .description ())
108- .parameters (params );
109- val tool = new ChatCompletionTool ().type (toolType ).function (toolFunction );
110- tools .add (tool );
111- } catch (JsonProcessingException ignored ) {
112- }
113- }
114- return tools ;
102+ return flux .map (
103+ delta -> {
104+ val assistantMessage = new AssistantMessage (delta .getDeltaContent (), Map .of ());
105+ val metadata =
106+ ChatGenerationMetadata .builder ().finishReason (delta .getFinishReason ()).build ();
107+ return new ChatResponse (List .of (new Generation (assistantMessage , metadata )));
108+ });
115109 }
116110
117- private static ChatResponse toChatResponse (final OpenAiChatCompletionDelta delta ) {
118- val assistantMessage = new AssistantMessage (delta .getDeltaContent (), Map .of ());
119- return new ChatResponse (List .of (new Generation (assistantMessage )));
120- }
121-
122- private List <OpenAiMessage > toOpenAiRequest (final Prompt prompt ) {
111+ private List <OpenAiMessage > extractMessages (final Prompt prompt ) {
123112 final List <OpenAiMessage > result = new ArrayList <>();
124113 for (final Message message : prompt .getInstructions ()) {
125114 switch (message .getMessageType ()) {
@@ -152,24 +141,73 @@ private static void addToolMessages(
152141 }
153142
154143 @ Nonnull
155- static List <Generation > toGenerations (@ Nonnull final OpenAiChatCompletionResponse result ) {
144+ private static List <Generation > toGenerations (
145+ @ Nonnull final OpenAiChatCompletionResponse result ) {
156146 return result .getOriginalResponse ().getChoices ().stream ()
157- .map (message -> toGeneration ( message . getMessage ()) )
147+ .map (OpenAiChatModel :: toGeneration )
158148 .toList ();
159149 }
160150
161151 @ Nonnull
162- static Generation toGeneration (@ Nonnull final ChatCompletionResponseMessage choice ) {
163- // no metadata for now
152+ private static Generation toGeneration (
153+ @ Nonnull final CreateChatCompletionResponseChoicesInner choice ) {
154+ val metadata =
155+ ChatGenerationMetadata .builder ().finishReason (choice .getFinishReason ().getValue ());
156+ metadata .metadata ("index" , choice .getIndex ());
157+ if (choice .getLogprobs () != null && !choice .getLogprobs ().getContent ().isEmpty ()) {
158+ metadata .metadata ("logprobs" , choice .getLogprobs ().getContent ());
159+ }
160+ val message = choice .getMessage ();
164161 val calls = new ArrayList <ToolCall >();
165- if (choice .getToolCalls () != null ) {
166- for (final ChatCompletionMessageToolCall c : choice .getToolCalls ()) {
162+ if (message .getToolCalls () != null ) {
163+ for (final ChatCompletionMessageToolCall c : message .getToolCalls ()) {
167164 val fnc = c .getFunction ();
168165 calls .add (
169166 new ToolCall (c .getId (), c .getType ().getValue (), fnc .getName (), fnc .getArguments ()));
170167 }
171168 }
172- val message = new AssistantMessage (choice .getContent (), Map .of (), calls );
173- return new Generation (message );
169+
170+ val assistantMessage = new AssistantMessage (message .getContent (), Map .of (), calls );
171+ return new Generation (assistantMessage , metadata .build ());
172+ }
173+
174+ private OpenAiChatCompletionRequest extractOptions (
175+ @ Nonnull OpenAiChatCompletionRequest request , @ Nonnull final ChatOptions options ) {
176+ request = request .withStop (options .getStopSequences ()).withMaxTokens (options .getMaxTokens ());
177+ if (options .getTemperature () != null ) {
178+ request = request .withTemperature (BigDecimal .valueOf (options .getTemperature ()));
179+ }
180+ if (options .getTopP () != null ) {
181+ request = request .withTopP (BigDecimal .valueOf (options .getTopP ()));
182+ }
183+ if (options .getPresencePenalty () != null ) {
184+ request = request .withPresencePenalty (BigDecimal .valueOf (options .getPresencePenalty ()));
185+ }
186+ if (options .getFrequencyPenalty () != null ) {
187+ request = request .withFrequencyPenalty (BigDecimal .valueOf (options .getFrequencyPenalty ()));
188+ }
189+ return request ;
190+ }
191+
192+ private List <ChatCompletionTool > extractTools (final ToolCallingChatOptions options ) {
193+ val tools = new ArrayList <ChatCompletionTool >();
194+ for (val toolCallback : options .getToolCallbacks ()) {
195+ val toolDefinition = toolCallback .getToolDefinition ();
196+ try {
197+ final Map <String , Object > params =
198+ new ObjectMapper ().readValue (toolDefinition .inputSchema (), new TypeReference <>() {});
199+ val tool =
200+ new ChatCompletionTool ()
201+ .type (ChatCompletionTool .TypeEnum .FUNCTION )
202+ .function (
203+ new FunctionObject ()
204+ .name (toolDefinition .name ())
205+ .description (toolDefinition .description ())
206+ .parameters (params ));
207+ tools .add (tool );
208+ } catch (JsonProcessingException ignored ) {
209+ }
210+ }
211+ return tools ;
174212 }
175213}
0 commit comments