11package com .sap .ai .sdk .foundationmodels .openai .spring ;
22
3- import static org .springframework .ai .model .tool .ToolCallingChatOptions .isInternalToolExecutionEnabled ;
4-
53import com .fasterxml .jackson .core .JsonProcessingException ;
64import com .fasterxml .jackson .core .type .TypeReference ;
75import com .fasterxml .jackson .databind .ObjectMapper ;
1210import com .sap .ai .sdk .foundationmodels .openai .OpenAiMessage ;
1311import com .sap .ai .sdk .foundationmodels .openai .OpenAiToolCall ;
1412import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionMessageToolCall ;
15- import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionResponseMessage ;
1613import com .sap .ai .sdk .foundationmodels .openai .generated .model .ChatCompletionTool ;
14+ import com .sap .ai .sdk .foundationmodels .openai .generated .model .CreateChatCompletionResponseChoicesInner ;
1715import com .sap .ai .sdk .foundationmodels .openai .generated .model .FunctionObject ;
1816import io .vavr .control .Option ;
19- import java .util .ArrayList ;
20- import java .util .List ;
21- import java .util .Map ;
22- import java .util .function .Function ;
23- import javax .annotation .Nonnull ;
2417import lombok .RequiredArgsConstructor ;
2518import lombok .val ;
2619import org .springframework .ai .chat .messages .AssistantMessage ;
2720import org .springframework .ai .chat .messages .AssistantMessage .ToolCall ;
2821import org .springframework .ai .chat .messages .Message ;
2922import org .springframework .ai .chat .messages .ToolResponseMessage ;
23+ import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
3024import org .springframework .ai .chat .model .ChatModel ;
3125import org .springframework .ai .chat .model .ChatResponse ;
3226import org .springframework .ai .chat .model .Generation ;
27+ import org .springframework .ai .chat .prompt .ChatOptions ;
3328import org .springframework .ai .chat .prompt .Prompt ;
34- import org .springframework .ai .model .tool .DefaultToolCallingChatOptions ;
3529import org .springframework .ai .model .tool .DefaultToolCallingManager ;
30+ import org .springframework .ai .model .tool .ToolCallingChatOptions ;
3631import reactor .core .publisher .Flux ;
3732
33+ import javax .annotation .Nonnull ;
34+ import java .math .BigDecimal ;
35+ import java .util .ArrayList ;
36+ import java .util .List ;
37+ import java .util .Map ;
38+ import java .util .function .Function ;
39+
40+ import static org .springframework .ai .model .tool .ToolCallingChatOptions .isInternalToolExecutionEnabled ;
41+
3842/**
3943 * OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions.
4044 */
@@ -50,34 +54,40 @@ public class OpenAiChatModel implements ChatModel {
5054 @ Override
5155 @ Nonnull
5256 public ChatResponse call (@ Nonnull final Prompt prompt ) {
53- val openAiRequest = toOpenAiRequest ( prompt );
54- var request = new OpenAiChatCompletionRequest (openAiRequest );
57+ val options = prompt . getOptions ( );
58+ var request = new OpenAiChatCompletionRequest (extractMessages ( prompt ) );
5559
56- if ((prompt .getOptions () instanceof DefaultToolCallingChatOptions options )) {
57- request = request .withTools (extractTools (options ));
60+ if (options != null ) {
61+ request = extractOptions (request , options );
62+ }
63+ if ((options instanceof ToolCallingChatOptions toolOptions )) {
64+ request = request .withTools (extractTools (toolOptions ));
5865 }
5966
6067 val result = client .chatCompletion (request );
6168 val response = new ChatResponse (toGenerations (result ));
6269
63- if (prompt .getOptions () != null
64- && isInternalToolExecutionEnabled (prompt .getOptions ())
65- && response .hasToolCalls ()) {
70+ if (options != null && isInternalToolExecutionEnabled (options ) && response .hasToolCalls ()) {
6671 val toolExecutionResult = toolCallingManager .executeToolCalls (prompt , response );
6772 // Send the tool execution result back to the model.
68- return call (new Prompt (toolExecutionResult .conversationHistory (), prompt . getOptions () ));
73+ return call (new Prompt (toolExecutionResult .conversationHistory (), options ));
6974 }
7075 return response ;
7176 }
7277
7378 @ Override
7479 @ Nonnull
7580 public Flux <ChatResponse > stream (@ Nonnull final Prompt prompt ) {
76- val openAiRequest = toOpenAiRequest (prompt );
77- var request = new OpenAiChatCompletionRequest (openAiRequest );
78- if ((prompt .getOptions () instanceof DefaultToolCallingChatOptions options )) {
79- request = request .withTools (extractTools (options ));
81+ val options = prompt .getOptions ();
82+ var request = new OpenAiChatCompletionRequest (extractMessages (prompt ));
83+
84+ if (options != null ) {
85+ request = extractOptions (request , options );
86+ }
87+ if ((options instanceof ToolCallingChatOptions toolOptions )) {
88+ request = request .withTools (extractTools (toolOptions ));
8089 }
90+
8191 val stream = client .streamChatCompletionDeltas (request );
8292 final Flux <OpenAiChatCompletionDelta > flux =
8393 Flux .generate (
@@ -90,37 +100,16 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
90100 }
91101 return iterator ;
92102 });
93- return flux .map (OpenAiChatModel ::toChatResponse );
103+ return flux .map (
104+ delta -> {
105+ val assistantMessage = new AssistantMessage (delta .getDeltaContent (), Map .of ());
106+ val metadata =
107+ ChatGenerationMetadata .builder ().finishReason (delta .getFinishReason ()).build ();
108+ return new ChatResponse (List .of (new Generation (assistantMessage , metadata )));
109+ });
94110 }
95111
96- private List <ChatCompletionTool > extractTools (final DefaultToolCallingChatOptions options ) {
97- val tools = new ArrayList <ChatCompletionTool >();
98- for (val toolCallback : options .getToolCallbacks ()) {
99- val toolDefinition = toolCallback .getToolDefinition ();
100- try {
101- final Map <String , Object > params =
102- new ObjectMapper ().readValue (toolDefinition .inputSchema (), new TypeReference <>() {});
103- val tool =
104- new ChatCompletionTool ()
105- .type (ChatCompletionTool .TypeEnum .FUNCTION )
106- .function (
107- new FunctionObject ()
108- .name (toolDefinition .name ())
109- .description (toolDefinition .description ())
110- .parameters (params ));
111- tools .add (tool );
112- } catch (JsonProcessingException ignored ) {
113- }
114- }
115- return tools ;
116- }
117-
118- private static ChatResponse toChatResponse (final OpenAiChatCompletionDelta delta ) {
119- val assistantMessage = new AssistantMessage (delta .getDeltaContent (), Map .of ());
120- return new ChatResponse (List .of (new Generation (assistantMessage )));
121- }
122-
123- private List <OpenAiMessage > toOpenAiRequest (final Prompt prompt ) {
112+ private List <OpenAiMessage > extractMessages (final Prompt prompt ) {
124113 final List <OpenAiMessage > result = new ArrayList <>();
125114 for (final Message message : prompt .getInstructions ()) {
126115 switch (message .getMessageType ()) {
@@ -153,24 +142,73 @@ private static void addToolMessages(
153142 }
154143
155144 @ Nonnull
156- static List <Generation > toGenerations (@ Nonnull final OpenAiChatCompletionResponse result ) {
145+ private static List <Generation > toGenerations (
146+ @ Nonnull final OpenAiChatCompletionResponse result ) {
157147 return result .getOriginalResponse ().getChoices ().stream ()
158- .map (message -> toGeneration ( message . getMessage ()) )
148+ .map (OpenAiChatModel :: toGeneration )
159149 .toList ();
160150 }
161151
162152 @ Nonnull
163- static Generation toGeneration (@ Nonnull final ChatCompletionResponseMessage choice ) {
164- // no metadata for now
153+ private static Generation toGeneration (
154+ @ Nonnull final CreateChatCompletionResponseChoicesInner choice ) {
155+ val metadata =
156+ ChatGenerationMetadata .builder ().finishReason (choice .getFinishReason ().getValue ());
157+ metadata .metadata ("index" , choice .getIndex ());
158+ if (choice .getLogprobs () != null && !choice .getLogprobs ().getContent ().isEmpty ()) {
159+ metadata .metadata ("logprobs" , choice .getLogprobs ().getContent ());
160+ }
161+ val message = choice .getMessage ();
165162 val calls = new ArrayList <ToolCall >();
166- if (choice .getToolCalls () != null ) {
167- for (final ChatCompletionMessageToolCall c : choice .getToolCalls ()) {
163+ if (message .getToolCalls () != null ) {
164+ for (final ChatCompletionMessageToolCall c : message .getToolCalls ()) {
168165 val fnc = c .getFunction ();
169166 calls .add (
170167 new ToolCall (c .getId (), c .getType ().getValue (), fnc .getName (), fnc .getArguments ()));
171168 }
172169 }
173- val message = new AssistantMessage (choice .getContent (), Map .of (), calls );
174- return new Generation (message );
170+
171+ val assistantMessage = new AssistantMessage (message .getContent (), Map .of (), calls );
172+ return new Generation (assistantMessage , metadata .build ());
173+ }
174+
175+ private OpenAiChatCompletionRequest extractOptions (
176+ @ Nonnull OpenAiChatCompletionRequest request , @ Nonnull final ChatOptions options ) {
177+ request = request .withStop (options .getStopSequences ()).withMaxTokens (options .getMaxTokens ());
178+ if (options .getTemperature () != null ) {
179+ request = request .withTemperature (BigDecimal .valueOf (options .getTemperature ()));
180+ }
181+ if (options .getTopP () != null ) {
182+ request = request .withTopP (BigDecimal .valueOf (options .getTopP ()));
183+ }
184+ if (options .getPresencePenalty () != null ) {
185+ request = request .withPresencePenalty (BigDecimal .valueOf (options .getPresencePenalty ()));
186+ }
187+ if (options .getFrequencyPenalty () != null ) {
188+ request = request .withFrequencyPenalty (BigDecimal .valueOf (options .getFrequencyPenalty ()));
189+ }
190+ return request ;
191+ }
192+
193+ private List <ChatCompletionTool > extractTools (final ToolCallingChatOptions options ) {
194+ val tools = new ArrayList <ChatCompletionTool >();
195+ for (val toolCallback : options .getToolCallbacks ()) {
196+ val toolDefinition = toolCallback .getToolDefinition ();
197+ try {
198+ final Map <String , Object > params =
199+ new ObjectMapper ().readValue (toolDefinition .inputSchema (), new TypeReference <>() {});
200+ val tool =
201+ new ChatCompletionTool ()
202+ .type (ChatCompletionTool .TypeEnum .FUNCTION )
203+ .function (
204+ new FunctionObject ()
205+ .name (toolDefinition .name ())
206+ .description (toolDefinition .description ())
207+ .parameters (params ));
208+ tools .add (tool );
209+ } catch (JsonProcessingException ignored ) {
210+ }
211+ }
212+ return tools ;
175213 }
176214}
0 commit comments