Skip to content

Commit f5761de

Browse files
committed
Fix Moonshot Chat model toolcalling token usage
- Accumulate the token usage when toolcalling is invoked - Fix both call() and stream() methods - Add `usage` field to the Chat completion choice as the usage is returned via Choice - Add Mootshot chatmodel ITs for functioncalling tests Move the tests into MoonshotChatModelFunctionCallingIT
1 parent 1c41c6a commit f5761de

File tree

5 files changed

+115
-12
lines changed

5 files changed

+115
-12
lines changed

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatModel.java

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3737
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
3838
import org.springframework.ai.chat.metadata.EmptyUsage;
39+
import org.springframework.ai.chat.metadata.Usage;
40+
import org.springframework.ai.chat.metadata.UsageUtils;
3941
import org.springframework.ai.chat.model.AbstractToolCallSupport;
4042
import org.springframework.ai.chat.model.ChatModel;
4143
import org.springframework.ai.chat.model.ChatResponse;
@@ -75,6 +77,7 @@
7577
*
7678
* @author Geng Rong
7779
* @author Alexandros Pappas
80+
* @author Ilayaperumal Gopinathan
7881
*/
7982
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {
8083

@@ -180,6 +183,10 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
180183

181184
@Override
182185
public ChatResponse call(Prompt prompt) {
186+
return this.internalCall(prompt, null);
187+
}
188+
189+
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
183190
ChatCompletionRequest request = createRequest(prompt, false);
184191

185192
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -218,8 +225,11 @@ public ChatResponse call(Prompt prompt) {
218225
// @formatter:on
219226
return buildGeneration(choice, metadata);
220227
}).toList();
221-
222-
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
228+
MoonshotApi.Usage usage = completionEntity.getBody().usage();
229+
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
230+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
231+
ChatResponse chatResponse = new ChatResponse(generations,
232+
from(completionEntity.getBody(), cumulativeUsage));
223233

224234
observationContext.setResponse(chatResponse);
225235

@@ -232,7 +242,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS
232242
var toolCallConversation = handleToolCalls(prompt, response);
233243
// Recursively call the call method with the tool call message
234244
// conversation that contains the call responses.
235-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
245+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
236246
}
237247
return response;
238248
}
@@ -244,6 +254,10 @@ public ChatOptions getDefaultOptions() {
244254

245255
@Override
246256
public Flux<ChatResponse> stream(Prompt prompt) {
257+
return this.internalStream(prompt, null);
258+
}
259+
260+
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
247261
return Flux.deferContextual(contextView -> {
248262
ChatCompletionRequest request = createRequest(prompt, true);
249263

@@ -287,8 +301,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
287301
// @formatter:on
288302
return buildGeneration(choice, metadata);
289303
}).toList();
304+
MoonshotApi.Usage usage = chatCompletion2.usage();
305+
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
306+
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
290307

291-
return new ChatResponse(generations, from(chatCompletion2));
308+
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
292309
}
293310
catch (Exception e) {
294311
logger.error("Error processing chat completion", e);
@@ -303,7 +320,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
303320
var toolCallConversation = handleToolCalls(prompt, response);
304321
// Recursively call the stream method with the tool call message
305322
// conversation that contains the call responses.
306-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
323+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
307324
}
308325
return Flux.just(response);
309326
})
@@ -325,6 +342,16 @@ private ChatResponseMetadata from(ChatCompletion result) {
325342
.build();
326343
}
327344

345+
private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
346+
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
347+
return ChatResponseMetadata.builder()
348+
.id(result.id() != null ? result.id() : "")
349+
.usage(usage)
350+
.model(result.model() != null ? result.model() : "")
351+
.keyValue("created", result.created() != null ? result.created() : 0L)
352+
.build();
353+
}
354+
328355
/**
329356
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
330357
* @param chunk the ChatCompletionChunk to convert
@@ -336,10 +363,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
336363
if (delta == null) {
337364
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
338365
}
339-
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason());
366+
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
340367
}).toList();
341-
342-
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
368+
// Get the usage from the latest choice
369+
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
370+
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
343371
}
344372

345373
/**

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ public record Choice(
532532
// @formatter:off
533533
@JsonProperty("index") Integer index,
534534
@JsonProperty("message") ChatCompletionMessage message,
535-
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
535+
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
536+
@JsonProperty("usage") Usage usage) {
536537
// @formatter:on
537538
}
538539

models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/api/MoonshotStreamFunctionCallingHelper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
6464
: previous.finishReason());
6565
Integer index = (current.index() != null ? current.index() : previous.index());
6666

67+
MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();
68+
6769
ChatCompletionMessage message = merge(previous.delta(), current.delta());
68-
return new ChunkChoice(index, message, finishReason, null);
70+
return new ChunkChoice(index, message, finishReason, usage);
6971
}
7072

7173
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public void beforeEach() {
8080
public void moonshotChatTransientError() {
8181

8282
var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
83-
ChatCompletionFinishReason.STOP);
83+
ChatCompletionFinishReason.STOP, null);
8484
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
8585
List.of(choice), new MoonshotApi.Usage(10, 10, 10));
8686

models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/chat/MoonshotChatModelFunctionCallingIT.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.moonshot.chat;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.Objects;
2223
import java.util.stream.Collectors;
@@ -53,6 +54,33 @@ class MoonshotChatModelFunctionCallingIT {
5354
@Autowired
5455
ChatModel chatModel;
5556

57+
private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
58+
MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
59+
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
60+
{
61+
"type": "object",
62+
"properties": {
63+
"location": {
64+
"type": "string",
65+
"description": "The city and state e.g. San Francisco, CA"
66+
},
67+
"lat": {
68+
"type": "number",
69+
"description": "The city latitude"
70+
},
71+
"lon": {
72+
"type": "number",
73+
"description": "The city longitude"
74+
},
75+
"unit": {
76+
"type": "string",
77+
"enum": ["C", "F"]
78+
}
79+
},
80+
"required": ["location", "lat", "lon", "unit"]
81+
}
82+
"""));
83+
5684
@Test
5785
void functionCallTest() {
5886

@@ -89,6 +117,7 @@ void streamFunctionCallTest() {
89117
.functionCallbacks(List.of(FunctionCallback.builder()
90118
.function("getCurrentWeather", new MockWeatherService())
91119
.description("Get the weather in location")
120+
.inputType(MockWeatherService.Request.class)
92121
.build()))
93122
.build();
94123

@@ -108,4 +137,47 @@ void streamFunctionCallTest() {
108137
assertThat(content).contains("30", "10", "15");
109138
}
110139

140+
@Test
141+
public void toolFunctionCallWithUsage() {
142+
var promptOptions = MoonshotChatOptions.builder()
143+
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
144+
.tools(Arrays.asList(FUNCTION_TOOL))
145+
.functionCallbacks(List.of(FunctionCallback.builder()
146+
.function("getCurrentWeather", new MockWeatherService())
147+
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
148+
.inputType(MockWeatherService.Request.class)
149+
.build()))
150+
.build();
151+
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
152+
promptOptions);
153+
154+
ChatResponse chatResponse = this.chatModel.call(prompt);
155+
assertThat(chatResponse).isNotNull();
156+
assertThat(chatResponse.getResult().getOutput());
157+
assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
158+
assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
159+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
160+
}
161+
162+
@Test
163+
public void testStreamFunctionCallUsage() {
164+
var promptOptions = MoonshotChatOptions.builder()
165+
.model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
166+
.tools(Arrays.asList(FUNCTION_TOOL))
167+
.functionCallbacks(List.of(FunctionCallback.builder()
168+
.function("getCurrentWeather", new MockWeatherService())
169+
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
170+
.inputType(MockWeatherService.Request.class)
171+
.build()))
172+
.build();
173+
Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
174+
promptOptions);
175+
176+
ChatResponse chatResponse = this.chatModel.stream(prompt).blockLast();
177+
assertThat(chatResponse).isNotNull();
178+
assertThat(chatResponse.getMetadata()).isNotNull();
179+
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
180+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
181+
}
182+
111183
}

0 commit comments

Comments
 (0)