diff --git a/.gitignore b/.gitignore index 2f7896d1d..15aedc70d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ target/ +.idea diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 69e8c761e..4b8bcd4e1 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -266,24 +266,25 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) { } private List toMessages(LlmRequest llmRequest) { - List messages = new ArrayList<>(); - messages.addAll(llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList()); - messages.addAll(llmRequest.contents().stream().map(this::toChatMessage).toList()); + List messages = + new ArrayList<>( + llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList()); + llmRequest.contents().forEach(content -> messages.addAll(toChatMessage(content))); return messages; } - private ChatMessage toChatMessage(Content content) { + private List toChatMessage(Content content) { String role = content.role().orElseThrow().toLowerCase(); return switch (role) { case "user" -> toUserOrToolResultMessage(content); - case "model", "assistant" -> toAiMessage(content); + case "model", "assistant" -> List.of(toAiMessage(content)); default -> throw new IllegalStateException("Unexpected role: " + role); }; } - private ChatMessage toUserOrToolResultMessage(Content content) { - ToolExecutionResultMessage toolExecutionResultMessage = null; - ToolExecutionRequest toolExecutionRequest = null; + private List toUserOrToolResultMessage(Content content) { + List toolExecutionResultMessages = new ArrayList<>(); + List toolExecutionRequests = new ArrayList<>(); List lc4jContents = new ArrayList<>(); @@ -292,19 +293,19 @@ private ChatMessage toUserOrToolResultMessage(Content content) { lc4jContents.add(TextContent.from(part.text().get())); } else if (part.functionResponse().isPresent()) { FunctionResponse functionResponse = part.functionResponse().get(); - toolExecutionResultMessage = + toolExecutionResultMessages.add( ToolExecutionResultMessage.from( functionResponse.id().orElseThrow(), functionResponse.name().orElseThrow(), - toJson(functionResponse.response().orElseThrow())); + toJson(functionResponse.response().orElseThrow()))); } else if (part.functionCall().isPresent()) { FunctionCall functionCall = part.functionCall().get(); - toolExecutionRequest = + toolExecutionRequests.add( ToolExecutionRequest.builder() .id(functionCall.id().orElseThrow()) .name(functionCall.name().orElseThrow()) .arguments(toJson(functionCall.args().orElse(Map.of()))) - .build(); + .build()); } else if (part.inlineData().isPresent()) { Blob blob = part.inlineData().get(); @@ -368,12 +369,15 @@ private ChatMessage toUserOrToolResultMessage(Content content) { } } - if (toolExecutionResultMessage != null) { - return toolExecutionResultMessage; - } else if (toolExecutionRequest != null) { - return AiMessage.aiMessage(toolExecutionRequest); + if (!toolExecutionResultMessages.isEmpty()) { + return new ArrayList(toolExecutionResultMessages); + } else if (!toolExecutionRequests.isEmpty()) { + return toolExecutionRequests.stream() + .map(AiMessage::aiMessage) + .map(msg -> (ChatMessage) msg) + .toList(); } else { - return UserMessage.from(lc4jContents); + return List.of(UserMessage.from(lc4jContents)); } } diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 6217730a2..aec82284a 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -162,6 +162,108 @@ void testGenerateContentWithFunctionCall() { assertThat(functionCall.args().get()).containsEntry("city", "Paris"); } + @Test + @DisplayName("Should handle multiple function calls in LLM responses") + void testGenerateContentWithMultipleFunctionCall() { + // Given + // Create mock FunctionTools + final FunctionTool weatherTool = mock(FunctionTool.class); + when(weatherTool.name()).thenReturn("getWeather"); + when(weatherTool.description()).thenReturn("Get weather for a city"); + + final FunctionTool timeTool = mock(FunctionTool.class); + when(timeTool.name()).thenReturn("getCurrentTime"); + when(timeTool.description()).thenReturn("Get current time for a city"); + + // Create mock FunctionDeclarations + final FunctionDeclaration weatherDeclaration = mock(FunctionDeclaration.class); + final FunctionDeclaration timeDeclaration = mock(FunctionDeclaration.class); + when(weatherTool.declaration()).thenReturn(Optional.of(weatherDeclaration)); + when(timeTool.declaration()).thenReturn(Optional.of(timeDeclaration)); + + // Create mock Schemas + final Schema weatherSchema = mock(Schema.class); + final Schema timeSchema = mock(Schema.class); + when(weatherDeclaration.parameters()).thenReturn(Optional.of(weatherSchema)); + when(timeDeclaration.parameters()).thenReturn(Optional.of(timeSchema)); + + // Create mock Types + final Type weatherType = mock(Type.class); + final Type timeType = mock(Type.class); + when(weatherSchema.type()).thenReturn(Optional.of(weatherType)); + when(timeSchema.type()).thenReturn(Optional.of(timeType)); + when(weatherType.knownEnum()).thenReturn(Type.Known.OBJECT); + when(timeType.knownEnum()).thenReturn(Type.Known.OBJECT); + + // Create mock schema properties + when(weatherSchema.properties()).thenReturn(Optional.of(Map.of("city", weatherSchema))); + when(timeSchema.properties()).thenReturn(Optional.of(Map.of("city", timeSchema))); + when(weatherSchema.required()).thenReturn(Optional.of(List.of("city"))); + when(timeSchema.required()).thenReturn(Optional.of(List.of("city"))); + + // Create LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents( + List.of( + Content.fromParts( + Part.fromText("What's the weather in Paris and the current time?")))) + .build(); + + // Mock multiple tool execution requests in the AI response + final ToolExecutionRequest weatherRequest = + ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final ToolExecutionRequest timeRequest = + ToolExecutionRequest.builder() + .id("456") + .name("getCurrentTime") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = + AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(weatherRequest, timeRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(2); + + // Verify first function call (getWeather) + assertThat(parts.get(0).functionCall()).isPresent(); + final FunctionCall weatherCall = parts.get(0).functionCall().orElseThrow(); + assertThat(weatherCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(weatherCall.args()).isPresent(); + assertThat(weatherCall.args().get()).containsEntry("city", "Paris"); + + // Verify second function call (getCurrentTime) + assertThat(parts.get(1).functionCall()).isPresent(); + final FunctionCall timeCall = parts.get(1).functionCall().orElseThrow(); + assertThat(timeCall.name()).isEqualTo(Optional.of("getCurrentTime")); + assertThat(timeCall.args()).isPresent(); + assertThat(timeCall.args().get()).containsEntry("city", "Paris"); + + // Verify the ChatModel was called + verify(chatModel).chat(any(ChatRequest.class)); + } + @Test @DisplayName("Should handle streaming responses correctly") void testGenerateContentWithStreamingChatModel() {