Skip to content

Commit 296ba3b

Browse files
sunyuhan1998chedim
authored andcommitted
fix(ollama): null check on chat model response' duration metadata
- Fixed an issue where the `from` method of `org.springframework.ai.ollama.OllamaChatModel` could not work correctly when a tool call occurred while using Ollama. Signed-off-by: Sun Yuhan <[email protected]>
1 parent 67b07d3 commit 296ba3b

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
* @author Jihoon Kim
8585
* @author Alexandros Pappas
8686
* @author Ilayaperumal Gopinathan
87+
* @author Sun Yuhan
8788
* @since 1.0.0
8889
*/
8990
public class OllamaChatModel implements ChatModel {
@@ -170,18 +171,21 @@ static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse p
170171
Duration totalDuration = response.getTotalDuration();
171172

172173
if (previousChatResponse != null && previousChatResponse.getMetadata() != null) {
173-
if (previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION) != null) {
174-
evalDuration = evalDuration.plus(previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION));
174+
Object metadataEvalDuration = previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION);
175+
if (metadataEvalDuration != null && evalDuration != null) {
176+
evalDuration = evalDuration.plus((Duration) metadataEvalDuration);
175177
}
176-
if (previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION) != null) {
177-
promptEvalDuration = promptEvalDuration
178-
.plus(previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION));
178+
Object metadataPromptEvalDuration = previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION);
179+
if (metadataPromptEvalDuration != null && promptEvalDuration != null) {
180+
promptEvalDuration = promptEvalDuration.plus((Duration) metadataPromptEvalDuration);
179181
}
180-
if (previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION) != null) {
181-
loadDuration = loadDuration.plus(previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION));
182+
Object metadataLoadDuration = previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION);
183+
if (metadataLoadDuration != null && loadDuration != null) {
184+
loadDuration = loadDuration.plus((Duration) metadataLoadDuration);
182185
}
183-
if (previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION) != null) {
184-
totalDuration = totalDuration.plus(previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION));
186+
Object metadataTotalDuration = previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION);
187+
if (metadataTotalDuration != null && totalDuration != null) {
188+
totalDuration = totalDuration.plus((Duration) metadataTotalDuration);
185189
}
186190
if (previousChatResponse.getMetadata().getUsage() != null) {
187191
promptTokens += previousChatResponse.getMetadata().getUsage().getPromptTokens();

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@
3737
import org.springframework.ai.ollama.management.ModelManagementOptions;
3838

3939
import static org.assertj.core.api.Assertions.assertThat;
40-
import static org.junit.jupiter.api.Assertions.assertEquals;
41-
import static org.junit.jupiter.api.Assertions.assertThrows;
40+
import static org.junit.jupiter.api.Assertions.*;
4241

4342
/**
4443
* @author Jihoon Kim
@@ -146,4 +145,28 @@ void buildChatResponseMetadataAggregationWithNonEmptyMetadata() {
146145
assertEquals(promptEvalCount + 66, (Integer) metadata.get("prompt-eval-count"));
147146
}
148147

148+
@Test
149+
void buildChatResponseMetadataAggregationWithNonEmptyMetadataButEmptyEval() {
150+
151+
OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), null, null, null, null,
152+
null, null, null, null, null);
153+
154+
ChatResponse previousChatResponse = ChatResponse.builder()
155+
.generations(List.of())
156+
.metadata(ChatResponseMetadata.builder()
157+
.usage(new DefaultUsage(66, 99))
158+
.keyValue("eval-duration", Duration.ofSeconds(2))
159+
.keyValue("prompt-eval-duration", Duration.ofSeconds(2))
160+
.build())
161+
.build();
162+
163+
ChatResponseMetadata metadata = OllamaChatModel.from(response, previousChatResponse);
164+
165+
assertNull(metadata.get("eval-duration"));
166+
assertNull(metadata.get("prompt-eval-duration"));
167+
assertEquals(Integer.valueOf(99), metadata.get("eval-count"));
168+
assertEquals(Integer.valueOf(66), metadata.get("prompt-eval-count"));
169+
170+
}
171+
149172
}

0 commit comments

Comments
 (0)