Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
target/
.idea
Original file line number Diff line number Diff line change
Expand Up @@ -266,24 +266,25 @@ private ChatRequest toChatRequest(LlmRequest llmRequest) {
}

private List<ChatMessage> toMessages(LlmRequest llmRequest) {
  // Builds the full chat history: system instructions first, then every
  // Content entry expanded into chat messages. A single Content may expand
  // into MORE than one message (e.g. several tool-execution results), which
  // is why toChatMessage returns a List and we use addAll per content.
  List<ChatMessage> messages =
      new ArrayList<>(
          llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList());
  llmRequest.contents().forEach(content -> messages.addAll(toChatMessage(content)));
  return messages;
}

private ChatMessage toChatMessage(Content content) {
private List<ChatMessage> toChatMessage(Content content) {
String role = content.role().orElseThrow().toLowerCase();
return switch (role) {
case "user" -> toUserOrToolResultMessage(content);
case "model", "assistant" -> toAiMessage(content);
case "model", "assistant" -> List.of(toAiMessage(content));
default -> throw new IllegalStateException("Unexpected role: " + role);
};
}

private ChatMessage toUserOrToolResultMessage(Content content) {
ToolExecutionResultMessage toolExecutionResultMessage = null;
ToolExecutionRequest toolExecutionRequest = null;
private List<ChatMessage> toUserOrToolResultMessage(Content content) {
List<ToolExecutionResultMessage> toolExecutionResultMessages = new ArrayList<>();
List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();

List<dev.langchain4j.data.message.Content> lc4jContents = new ArrayList<>();

Expand All @@ -292,19 +293,19 @@ private ChatMessage toUserOrToolResultMessage(Content content) {
lc4jContents.add(TextContent.from(part.text().get()));
} else if (part.functionResponse().isPresent()) {
FunctionResponse functionResponse = part.functionResponse().get();
toolExecutionResultMessage =
toolExecutionResultMessages.add(
ToolExecutionResultMessage.from(
functionResponse.id().orElseThrow(),
functionResponse.name().orElseThrow(),
toJson(functionResponse.response().orElseThrow()));
toJson(functionResponse.response().orElseThrow())));
} else if (part.functionCall().isPresent()) {
FunctionCall functionCall = part.functionCall().get();
toolExecutionRequest =
toolExecutionRequests.add(
ToolExecutionRequest.builder()
.id(functionCall.id().orElseThrow())
.name(functionCall.name().orElseThrow())
.arguments(toJson(functionCall.args().orElse(Map.of())))
.build();
.build());
} else if (part.inlineData().isPresent()) {
Blob blob = part.inlineData().get();

Expand Down Expand Up @@ -368,12 +369,15 @@ private ChatMessage toUserOrToolResultMessage(Content content) {
}
}

if (toolExecutionResultMessage != null) {
return toolExecutionResultMessage;
} else if (toolExecutionRequest != null) {
return AiMessage.aiMessage(toolExecutionRequest);
if (!toolExecutionResultMessages.isEmpty()) {
return new ArrayList<ChatMessage>(toolExecutionResultMessages);
} else if (!toolExecutionRequests.isEmpty()) {
return toolExecutionRequests.stream()
.map(AiMessage::aiMessage)
.map(msg -> (ChatMessage) msg)
.toList();
} else {
return UserMessage.from(lc4jContents);
return List.of(UserMessage.from(lc4jContents));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,108 @@ void testGenerateContentWithFunctionCall() {
assertThat(functionCall.args().get()).containsEntry("city", "Paris");
}

@Test
@DisplayName("Should handle multiple function calls in LLM responses")
void testGenerateContentWithMultipleFunctionCall() {
  // Given: a weather tool whose declaration exposes an OBJECT schema with a
  // single required "city" property.
  FunctionTool weatherTool = mock(FunctionTool.class);
  FunctionDeclaration weatherDeclaration = mock(FunctionDeclaration.class);
  Schema weatherSchema = mock(Schema.class);
  Type weatherType = mock(Type.class);
  when(weatherTool.name()).thenReturn("getWeather");
  when(weatherTool.description()).thenReturn("Get weather for a city");
  when(weatherTool.declaration()).thenReturn(Optional.of(weatherDeclaration));
  when(weatherDeclaration.parameters()).thenReturn(Optional.of(weatherSchema));
  when(weatherSchema.type()).thenReturn(Optional.of(weatherType));
  when(weatherType.knownEnum()).thenReturn(Type.Known.OBJECT);
  when(weatherSchema.properties()).thenReturn(Optional.of(Map.of("city", weatherSchema)));
  when(weatherSchema.required()).thenReturn(Optional.of(List.of("city")));

  // And a time tool mocked the same way.
  FunctionTool timeTool = mock(FunctionTool.class);
  FunctionDeclaration timeDeclaration = mock(FunctionDeclaration.class);
  Schema timeSchema = mock(Schema.class);
  Type timeType = mock(Type.class);
  when(timeTool.name()).thenReturn("getCurrentTime");
  when(timeTool.description()).thenReturn("Get current time for a city");
  when(timeTool.declaration()).thenReturn(Optional.of(timeDeclaration));
  when(timeDeclaration.parameters()).thenReturn(Optional.of(timeSchema));
  when(timeSchema.type()).thenReturn(Optional.of(timeType));
  when(timeType.knownEnum()).thenReturn(Type.Known.OBJECT);
  when(timeSchema.properties()).thenReturn(Optional.of(Map.of("city", timeSchema)));
  when(timeSchema.required()).thenReturn(Optional.of(List.of("city")));

  // A request asking for both pieces of information in one turn.
  LlmRequest llmRequest =
      LlmRequest.builder()
          .contents(
              List.of(
                  Content.fromParts(
                      Part.fromText("What's the weather in Paris and the current time?"))))
          .build();

  // The mocked model replies with TWO tool-execution requests in a single
  // AiMessage — the scenario under test.
  ToolExecutionRequest weatherRequest =
      ToolExecutionRequest.builder()
          .id("123")
          .name("getWeather")
          .arguments("{\"city\":\"Paris\"}")
          .build();
  ToolExecutionRequest timeRequest =
      ToolExecutionRequest.builder()
          .id("456")
          .name("getCurrentTime")
          .arguments("{\"city\":\"Paris\"}")
          .build();
  AiMessage aiMessage =
      AiMessage.builder()
          .text("")
          .toolExecutionRequests(List.of(weatherRequest, timeRequest))
          .build();
  ChatResponse chatResponse = mock(ChatResponse.class);
  when(chatResponse.aiMessage()).thenReturn(aiMessage);
  when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

  // When
  LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();

  // Then: both function calls survive the round trip, in order.
  assertThat(response).isNotNull();
  assertThat(response.content()).isPresent();
  assertThat(response.content().get().parts()).isPresent();

  List<Part> parts = response.content().get().parts().orElseThrow();
  assertThat(parts).hasSize(2);

  // First part: the getWeather call with its city argument intact.
  assertThat(parts.get(0).functionCall()).isPresent();
  FunctionCall weatherCall = parts.get(0).functionCall().orElseThrow();
  assertThat(weatherCall.name()).isEqualTo(Optional.of("getWeather"));
  assertThat(weatherCall.args()).isPresent();
  assertThat(weatherCall.args().get()).containsEntry("city", "Paris");

  // Second part: the getCurrentTime call with its city argument intact.
  assertThat(parts.get(1).functionCall()).isPresent();
  FunctionCall timeCall = parts.get(1).functionCall().orElseThrow();
  assertThat(timeCall.name()).isEqualTo(Optional.of("getCurrentTime"));
  assertThat(timeCall.args()).isPresent();
  assertThat(timeCall.args().get()).containsEntry("city", "Paris");

  // And the underlying ChatModel was invoked exactly once.
  verify(chatModel).chat(any(ChatRequest.class));
}

@Test
@DisplayName("Should handle streaming responses correctly")
void testGenerateContentWithStreamingChatModel() {
Expand Down