Skip to content

Commit 92631a1

Browse files
dantelmomsftcopybara-github
authored andcommitted
fix: multiple tool requests with langchain4j
Merge #246 fix #239 . @glaforge COPYBARA_INTEGRATE_REVIEW=#246 from dantelmomsft:main bcb5431 PiperOrigin-RevId: 782170311
1 parent a348a30 commit 92631a1

File tree

3 files changed

+118
-17
lines changed

3 files changed

+118
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
target/
2+
.idea

contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -266,24 +266,25 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) {
266266
}
267267

268268
private List<ChatMessage> toMessages(LlmRequest llmRequest) {
269-
List<ChatMessage> messages = new ArrayList<>();
270-
messages.addAll(llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList());
271-
messages.addAll(llmRequest.contents().stream().map(this::toChatMessage).toList());
269+
List<ChatMessage> messages =
270+
new ArrayList<>(
271+
llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList());
272+
llmRequest.contents().forEach(content -> messages.addAll(toChatMessage(content)));
272273
return messages;
273274
}
274275

275-
private ChatMessage toChatMessage(Content content) {
276+
private List<ChatMessage> toChatMessage(Content content) {
276277
String role = content.role().orElseThrow().toLowerCase();
277278
return switch (role) {
278279
case "user" -> toUserOrToolResultMessage(content);
279-
case "model", "assistant" -> toAiMessage(content);
280+
case "model", "assistant" -> List.of(toAiMessage(content));
280281
default -> throw new IllegalStateException("Unexpected role: " + role);
281282
};
282283
}
283284

284-
private ChatMessage toUserOrToolResultMessage(Content content) {
285-
ToolExecutionResultMessage toolExecutionResultMessage = null;
286-
ToolExecutionRequest toolExecutionRequest = null;
285+
private List<ChatMessage> toUserOrToolResultMessage(Content content) {
286+
List<ToolExecutionResultMessage> toolExecutionResultMessages = new ArrayList<>();
287+
List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();
287288

288289
List<dev.langchain4j.data.message.Content> lc4jContents = new ArrayList<>();
289290

@@ -292,19 +293,19 @@ private ChatMessage toUserOrToolResultMessage(Content content) {
292293
lc4jContents.add(TextContent.from(part.text().get()));
293294
} else if (part.functionResponse().isPresent()) {
294295
FunctionResponse functionResponse = part.functionResponse().get();
295-
toolExecutionResultMessage =
296+
toolExecutionResultMessages.add(
296297
ToolExecutionResultMessage.from(
297298
functionResponse.id().orElseThrow(),
298299
functionResponse.name().orElseThrow(),
299-
toJson(functionResponse.response().orElseThrow()));
300+
toJson(functionResponse.response().orElseThrow())));
300301
} else if (part.functionCall().isPresent()) {
301302
FunctionCall functionCall = part.functionCall().get();
302-
toolExecutionRequest =
303+
toolExecutionRequests.add(
303304
ToolExecutionRequest.builder()
304305
.id(functionCall.id().orElseThrow())
305306
.name(functionCall.name().orElseThrow())
306307
.arguments(toJson(functionCall.args().orElse(Map.of())))
307-
.build();
308+
.build());
308309
} else if (part.inlineData().isPresent()) {
309310
Blob blob = part.inlineData().get();
310311

@@ -368,12 +369,15 @@ private ChatMessage toUserOrToolResultMessage(Content content) {
368369
}
369370
}
370371

371-
if (toolExecutionResultMessage != null) {
372-
return toolExecutionResultMessage;
373-
} else if (toolExecutionRequest != null) {
374-
return AiMessage.aiMessage(toolExecutionRequest);
372+
if (!toolExecutionResultMessages.isEmpty()) {
373+
return new ArrayList<ChatMessage>(toolExecutionResultMessages);
374+
} else if (!toolExecutionRequests.isEmpty()) {
375+
return toolExecutionRequests.stream()
376+
.map(AiMessage::aiMessage)
377+
.map(msg -> (ChatMessage) msg)
378+
.toList();
375379
} else {
376-
return UserMessage.from(lc4jContents);
380+
return List.of(UserMessage.from(lc4jContents));
377381
}
378382
}
379383

contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,102 @@ void testGenerateContentWithFunctionCall() {
162162
assertThat(functionCall.args().get()).containsEntry("city", "Paris");
163163
}
164164

165+
@Test
166+
@DisplayName("Should handle multiple function calls in LLM responses")
167+
void testGenerateContentWithMultipleFunctionCall() {
168+
// Given
169+
// Create mock FunctionTools
170+
final FunctionTool weatherTool = mock(FunctionTool.class);
171+
when(weatherTool.name()).thenReturn("getWeather");
172+
when(weatherTool.description()).thenReturn("Get weather for a city");
173+
174+
final FunctionTool timeTool = mock(FunctionTool.class);
175+
when(timeTool.name()).thenReturn("getCurrentTime");
176+
when(timeTool.description()).thenReturn("Get current time for a city");
177+
178+
// Create mock FunctionDeclarations
179+
final FunctionDeclaration weatherDeclaration = mock(FunctionDeclaration.class);
180+
final FunctionDeclaration timeDeclaration = mock(FunctionDeclaration.class);
181+
when(weatherTool.declaration()).thenReturn(Optional.of(weatherDeclaration));
182+
when(timeTool.declaration()).thenReturn(Optional.of(timeDeclaration));
183+
184+
// Create mock Schemas
185+
final Schema weatherSchema = mock(Schema.class);
186+
final Schema timeSchema = mock(Schema.class);
187+
when(weatherDeclaration.parameters()).thenReturn(Optional.of(weatherSchema));
188+
when(timeDeclaration.parameters()).thenReturn(Optional.of(timeSchema));
189+
190+
// Create mock Types
191+
final Type weatherType = mock(Type.class);
192+
final Type timeType = mock(Type.class);
193+
when(weatherSchema.type()).thenReturn(Optional.of(weatherType));
194+
when(timeSchema.type()).thenReturn(Optional.of(timeType));
195+
when(weatherType.knownEnum()).thenReturn(Type.Known.OBJECT);
196+
when(timeType.knownEnum()).thenReturn(Type.Known.OBJECT);
197+
198+
// Create mock schema properties
199+
when(weatherSchema.properties()).thenReturn(Optional.of(Map.of("city", weatherSchema)));
200+
when(timeSchema.properties()).thenReturn(Optional.of(Map.of("city", timeSchema)));
201+
when(weatherSchema.required()).thenReturn(Optional.of(List.of("city")));
202+
when(timeSchema.required()).thenReturn(Optional.of(List.of("city")));
203+
204+
// Create LlmRequest
205+
final LlmRequest llmRequest = LlmRequest.builder()
206+
.contents(List.of(Content.fromParts(Part.fromText("What's the weather in Paris and the current time?"))))
207+
.build();
208+
209+
// Mock multiple tool execution requests in the AI response
210+
final ToolExecutionRequest weatherRequest = ToolExecutionRequest.builder()
211+
.id("123")
212+
.name("getWeather")
213+
.arguments("{\"city\":\"Paris\"}")
214+
.build();
215+
216+
final ToolExecutionRequest timeRequest = ToolExecutionRequest.builder()
217+
.id("456")
218+
.name("getCurrentTime")
219+
.arguments("{\"city\":\"Paris\"}")
220+
.build();
221+
222+
final AiMessage aiMessage = AiMessage.builder()
223+
.text("")
224+
.toolExecutionRequests(List.of(weatherRequest, timeRequest))
225+
.build();
226+
227+
final ChatResponse chatResponse = mock(ChatResponse.class);
228+
when(chatResponse.aiMessage()).thenReturn(aiMessage);
229+
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
230+
231+
// When
232+
final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();
233+
234+
// Then
235+
assertThat(response).isNotNull();
236+
assertThat(response.content()).isPresent();
237+
assertThat(response.content().get().parts()).isPresent();
238+
239+
final List<Part> parts = response.content().get().parts().orElseThrow();
240+
assertThat(parts).hasSize(2);
241+
242+
// Verify first function call (getWeather)
243+
assertThat(parts.get(0).functionCall()).isPresent();
244+
final FunctionCall weatherCall = parts.get(0).functionCall().orElseThrow();
245+
assertThat(weatherCall.name()).isEqualTo(Optional.of("getWeather"));
246+
assertThat(weatherCall.args()).isPresent();
247+
assertThat(weatherCall.args().get()).containsEntry("city", "Paris");
248+
249+
// Verify second function call (getCurrentTime)
250+
assertThat(parts.get(1).functionCall()).isPresent();
251+
final FunctionCall timeCall = parts.get(1).functionCall().orElseThrow();
252+
assertThat(timeCall.name()).isEqualTo(Optional.of("getCurrentTime"));
253+
assertThat(timeCall.args()).isPresent();
254+
assertThat(timeCall.args().get()).containsEntry("city", "Paris");
255+
256+
// Verify the ChatModel was called
257+
verify(chatModel).chat(any(ChatRequest.class));
258+
}
259+
260+
165261
@Test
166262
@DisplayName("Should handle streaming responses correctly")
167263
void testGenerateContentWithStreamingChatModel() {

0 commit comments

Comments
 (0)