Skip to content

Commit 93ae756

Browse files
dev-jonghoonparknamsoo2
authored andcommitted
spring-projectsGH-3118: Fix NPE in OpenAIChatModel compatible streaming api with vertex ai gemini
Fixes: spring-projects#3118 Signed-off-by: jonghoon park <[email protected]> Signed-off-by: minsoo.nam <[email protected]>
1 parent 022bc27 commit 93ae756

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
* @author Ilayaperumal Gopinathan
105105
* @author Alexandros Pappas
106106
* @author Soby Chacko
107+
* @author Jonghoon Park
107108
* @see ChatModel
108109
* @see StreamingChatModel
109110
* @see OpenAiApi
@@ -304,15 +305,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
304305
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
305306
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
306307
try {
307-
@SuppressWarnings("null")
308-
String id = chatCompletion2.id();
308+
// If an id is not provided, set to "NO_ID" (for compatible APIs).
309+
String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id();
309310

310311
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
311312
if (choice.message().role() != null) {
312313
roleMap.putIfAbsent(id, choice.message().role().name());
313314
}
314315
Map<String, Object> metadata = Map.of(
315-
"id", chatCompletion2.id(),
316+
"id", id,
316317
"role", roleMap.getOrDefault(id, ""),
317318
"index", choice.index(),
318319
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.proxy;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
import org.springframework.ai.chat.messages.AssistantMessage;
22+
import org.springframework.ai.chat.messages.Message;
23+
import org.springframework.ai.chat.messages.UserMessage;
24+
import org.springframework.ai.chat.model.ChatResponse;
25+
import org.springframework.ai.chat.model.Generation;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
28+
import org.springframework.ai.model.SimpleApiKey;
29+
import org.springframework.ai.model.tool.ToolCallingManager;
30+
import org.springframework.ai.openai.OpenAiChatModel;
31+
import org.springframework.ai.openai.OpenAiChatOptions;
32+
import org.springframework.ai.openai.api.OpenAiApi;
33+
import org.springframework.beans.factory.annotation.Autowired;
34+
import org.springframework.beans.factory.annotation.Value;
35+
import org.springframework.boot.SpringBootConfiguration;
36+
import org.springframework.boot.test.context.SpringBootTest;
37+
import org.springframework.context.annotation.Bean;
38+
import org.springframework.core.io.Resource;
39+
import reactor.core.publisher.Flux;
40+
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.stream.Collectors;
44+
45+
import static org.assertj.core.api.Assertions.assertThat;
46+
47+
/**
48+
* @author Jonghoon Park
49+
*/
50+
@SpringBootTest(classes = VertexAIGeminiWithOpenAiChatModelIT.Config.class)
51+
@EnabledIfEnvironmentVariable(named = "GEMINI_API_KEY", matches = ".+")
52+
class VertexAIGeminiWithOpenAiChatModelIT {
53+
54+
private static final String VERTEX_AI_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com";
55+
56+
private static final String VERTEX_AI_GEMINI_DEFAULT_MODEL = "gemini-2.0-flash";
57+
58+
@Value("classpath:/prompts/system-message.st")
59+
private Resource systemResource;
60+
61+
@Autowired
62+
private OpenAiChatModel chatModel;
63+
64+
@Test
65+
void roleTest() {
66+
UserMessage userMessage = new UserMessage(
67+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
68+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
69+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
70+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
71+
ChatResponse response = this.chatModel.call(prompt);
72+
assertThat(response.getResults()).hasSize(1);
73+
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
74+
}
75+
76+
@Test
77+
void streamRoleTest() {
78+
UserMessage userMessage = new UserMessage(
79+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
80+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
81+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
82+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
83+
Flux<ChatResponse> flux = this.chatModel.stream(prompt);
84+
85+
List<ChatResponse> responses = flux.collectList().block();
86+
assertThat(responses.size()).isGreaterThan(1);
87+
88+
String stitchedResponseContent = responses.stream()
89+
.map(ChatResponse::getResults)
90+
.flatMap(List::stream)
91+
.map(Generation::getOutput)
92+
.map(AssistantMessage::getText)
93+
.collect(Collectors.joining());
94+
95+
assertThat(stitchedResponseContent).contains("Blackbeard");
96+
}
97+
98+
@SpringBootConfiguration
99+
static class Config {
100+
101+
@Bean
102+
public OpenAiApi chatCompletionApi() {
103+
return OpenAiApi.builder()
104+
.baseUrl(VERTEX_AI_GEMINI_BASE_URL)
105+
.completionsPath("/v1beta/openai/chat/completions")
106+
.apiKey(new SimpleApiKey(System.getenv("GEMINI_API_KEY")))
107+
.build();
108+
}
109+
110+
@Bean
111+
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
112+
return OpenAiChatModel.builder()
113+
.openAiApi(openAiApi)
114+
.toolCallingManager(ToolCallingManager.builder().build())
115+
.defaultOptions(OpenAiChatOptions.builder().model(VERTEX_AI_GEMINI_DEFAULT_MODEL).build())
116+
.build();
117+
}
118+
119+
}
120+
121+
}

0 commit comments

Comments
 (0)