4848import org .slf4j .Logger ;
4949import org .slf4j .LoggerFactory ;
5050import reactor .core .publisher .Flux ;
51- import reactor .core .publisher .Mono ;
5251import reactor .core .scheduler .Schedulers ;
5352
5453import org .springframework .ai .chat .messages .AssistantMessage ;
6059import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
6160import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
6261import org .springframework .ai .chat .metadata .DefaultUsage ;
63- import org .springframework .ai .chat .model .AbstractToolCallSupport ;
62+ import org .springframework .ai .chat .metadata .EmptyUsage ;
63+ import org .springframework .ai .chat .metadata .Usage ;
64+ import org .springframework .ai .chat .metadata .UsageUtils ;
6465import org .springframework .ai .chat .model .ChatModel ;
6566import org .springframework .ai .chat .model .ChatResponse ;
6667import org .springframework .ai .chat .model .Generation ;
7172import org .springframework .ai .chat .observation .DefaultChatModelObservationConvention ;
7273import org .springframework .ai .chat .prompt .ChatOptions ;
7374import org .springframework .ai .chat .prompt .Prompt ;
74- import org .springframework .ai .model .ChatModelDescription ;
7575import org .springframework .ai .content .Media ;
76+ import org .springframework .ai .model .ChatModelDescription ;
7677import org .springframework .ai .model .ModelOptionsUtils ;
7778import org .springframework .ai .model .function .FunctionCallback ;
7879import org .springframework .ai .model .function .FunctionCallbackResolver ;
79- import org .springframework .ai .model .function .FunctionCallingOptions ;
8080import org .springframework .ai .model .tool .DefaultToolExecutionEligibilityPredicate ;
8181import org .springframework .ai .model .tool .LegacyToolCallingManager ;
8282import org .springframework .ai .model .tool .ToolCallingChatOptions ;
 * @author Soby Chacko
 * @author Jihoon Kim
 * @author Alexandros Pappas
 * @author Ilayaperumal Gopinathan
 * @since 0.8.1
 * @see VertexAiGeminiChatOptions
 * @see ToolCallingManager
 * @see ChatModel
 */
144- public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel , DisposableBean {
145+ public class VertexAiGeminiChatModel implements ChatModel , DisposableBean {
145146
146147 private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention ();
147148
@@ -277,8 +278,6 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
277278 ToolCallingManager toolCallingManager , RetryTemplate retryTemplate , ObservationRegistry observationRegistry ,
278279 ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate ) {
279280
280- super (null , VertexAiGeminiChatOptions .builder ().build (), List .of ());
281-
282281 Assert .notNull (vertexAI , "VertexAI must not be null" );
283282 Assert .notNull (defaultOptions , "VertexAiGeminiChatOptions must not be null" );
284283 Assert .notNull (defaultOptions .getModel (), "VertexAiGeminiChatOptions.modelName must not be null" );
@@ -425,10 +424,10 @@ private static Schema jsonToSchema(String json) {
425424 @ Override
426425 public ChatResponse call (Prompt prompt ) {
427426 var requestPrompt = this .buildRequestPrompt (prompt );
428- return this .internalCall (requestPrompt );
427+ return this .internalCall (requestPrompt , null );
429428 }
430429
431- private ChatResponse internalCall (Prompt prompt ) {
430+ private ChatResponse internalCall (Prompt prompt , ChatResponse previousChatResponse ) {
432431
433432 ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
434433 .prompt (prompt )
@@ -451,8 +450,12 @@ private ChatResponse internalCall(Prompt prompt) {
451450 .flatMap (List ::stream )
452451 .toList ();
453452
454- ChatResponse chatResponse = new ChatResponse (generations ,
455- toChatResponseMetadata (generateContentResponse ));
453+ GenerateContentResponse .UsageMetadata usage = generateContentResponse .getUsageMetadata ();
454+ Usage currentUsage = (usage != null )
455+ ? new DefaultUsage (usage .getPromptTokenCount (), usage .getCandidatesTokenCount ())
456+ : new EmptyUsage ();
457+ Usage cumulativeUsage = UsageUtils .getCumulativeUsage (currentUsage , previousChatResponse );
458+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (cumulativeUsage ));
456459
457460 observationContext .setResponse (chatResponse );
458461 return chatResponse ;
@@ -469,7 +472,8 @@ private ChatResponse internalCall(Prompt prompt) {
469472 }
470473 else {
471474 // Send the tool execution result back to the model.
472- return this .internalCall (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()));
475+ return this .internalCall (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
476+ response );
473477 }
474478 }
475479
@@ -485,10 +489,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
485489 runtimeOptions = ModelOptionsUtils .copyToTarget (toolCallingChatOptions , ToolCallingChatOptions .class ,
486490 VertexAiGeminiChatOptions .class );
487491 }
488- else if (prompt .getOptions () instanceof FunctionCallingOptions functionCallingOptions ) {
489- runtimeOptions = ModelOptionsUtils .copyToTarget (functionCallingOptions , FunctionCallingOptions .class ,
490- VertexAiGeminiChatOptions .class );
491- }
492492 else {
493493 runtimeOptions = ModelOptionsUtils .copyToTarget (prompt .getOptions (), ChatOptions .class ,
494494 VertexAiGeminiChatOptions .class );
@@ -535,10 +535,10 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
535535 @ Override
536536 public Flux <ChatResponse > stream (Prompt prompt ) {
537537 var requestPrompt = this .buildRequestPrompt (prompt );
538- return this .internalStream (requestPrompt );
538+ return this .internalStream (requestPrompt , null );
539539 }
540540
541- public Flux <ChatResponse > internalStream (Prompt prompt ) {
541+ public Flux <ChatResponse > internalStream (Prompt prompt , ChatResponse previousChatResponse ) {
542542 return Flux .deferContextual (contextView -> {
543543
544544 ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
@@ -559,21 +559,22 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
559559 ResponseStream <GenerateContentResponse > responseStream = request .model
560560 .generateContentStream (request .contents );
561561
562- Flux <ChatResponse > chatResponse1 = Flux .fromStream (responseStream .stream ())
563- .switchMap (response2 -> Mono .just (response2 ).map (response -> {
564-
565- List <Generation > generations = response .getCandidatesList ()
566- .stream ()
567- .map (this ::responseCandidateToGeneration )
568- .flatMap (List ::stream )
569- .toList ();
570-
571- return new ChatResponse (generations , toChatResponseMetadata (response ));
562+ Flux <ChatResponse > chatResponseFlux = Flux .fromStream (responseStream .stream ()).switchMap (response -> {
563+ List <Generation > generations = response .getCandidatesList ()
564+ .stream ()
565+ .map (this ::responseCandidateToGeneration )
566+ .flatMap (List ::stream )
567+ .toList ();
572568
573- }));
569+ GenerateContentResponse .UsageMetadata usage = response .getUsageMetadata ();
570+ Usage currentUsage = (usage != null ) ? getDefaultUsage (usage ) : new EmptyUsage ();
571+ Usage cumulativeUsage = UsageUtils .getCumulativeUsage (currentUsage , previousChatResponse );
572+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (cumulativeUsage ));
573+ return Flux .just (chatResponse );
574+ });
574575
575576 // @formatter:off
576- Flux <ChatResponse > chatResponseFlux = chatResponse1 .flatMap (response -> {
577+ Flux <ChatResponse > flux = chatResponseFlux .flatMap (response -> {
577578 if (toolExecutionEligibilityPredicate .isToolExecutionRequired (prompt .getOptions (), response )) {
578579 // FIXME: bounded elastic needs to be used since tool calling
579580 // is currently only synchronous
@@ -586,7 +587,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
586587 .build ());
587588 } else {
588589 // Send the tool execution result back to the model.
589- return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()));
590+ return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()), response );
590591 }
591592 }).subscribeOn (Schedulers .boundedElastic ());
592593 }
@@ -599,7 +600,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
599600 .contextWrite (ctx -> ctx .put (ObservationThreadLocalAccessor .KEY , observation ));
600601 // @formatter:on;
601602
602- return new MessageAggregator ().aggregate (chatResponseFlux , observationContext ::setResponse );
603+ return new MessageAggregator ().aggregate (flux , observationContext ::setResponse );
603604
604605 }
605606 catch (Exception e ) {
@@ -653,8 +654,8 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
653654 }
654655 }
655656
656- private ChatResponseMetadata toChatResponseMetadata (GenerateContentResponse response ) {
657- return ChatResponseMetadata .builder ().usage (getDefaultUsage ( response . getUsageMetadata ()) ).build ();
657+ private ChatResponseMetadata toChatResponseMetadata (Usage usage ) {
658+ return ChatResponseMetadata .builder ().usage (usage ).build ();
658659 }
659660
660661 private DefaultUsage getDefaultUsage (GenerateContentResponse .UsageMetadata usageMetadata ) {
0 commit comments