Skip to content

Commit 7ff451c

Browse files
committed
change Agent interface to return a list of messages
1 parent a212fdc commit 7ff451c

File tree

4 files changed

+19
-18
lines changed

4 files changed

+19
-18
lines changed

app/copilot/copilot-backend/src/main/java/com/microsoft/openai/samples/assistant/controller/ChatController.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ public ResponseEntity<ChatResponse> openAIAsk(@RequestBody ChatAppRequest chatRe
5353

5454
LOGGER.debug("Processing chat conversation..", chatHistory.get(chatHistory.size()-1));
5555

56-
supervisorAgent.invoke(chatHistory);
56+
List<ChatMessage> agentsResponse = supervisorAgent.invoke(chatHistory);
5757

58-
AiMessage generatedResponse = (AiMessage) chatHistory.get(chatHistory.size()-1);
58+
AiMessage generatedResponse = (AiMessage) agentsResponse.get(agentsResponse.size()-1);
5959
return ResponseEntity.ok(
6060
ChatResponse.buildChatResponse(generatedResponse));
6161
}

app/copilot/langchain4j-agents/src/main/java/com/microsoft/langchain4j/agent/AbstractReActAgent.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import java.util.ArrayList;
1818
import java.util.List;
19+
import java.util.stream.Collectors;
1920

2021
public abstract class AbstractReActAgent implements Agent {
2122

@@ -31,7 +32,7 @@ protected AbstractReActAgent(ChatLanguageModel chatModel) {
3132
}
3233

3334
@Override
34-
public void invoke(List<ChatMessage> chatHistory) throws AgentExecutionException {
35+
public List<ChatMessage> invoke(List<ChatMessage> chatHistory) throws AgentExecutionException {
3536
LOGGER.info("------------- {} -------------", this.getName());
3637

3738
try {
@@ -67,20 +68,17 @@ public void invoke(List<ChatMessage> chatHistory) throws AgentExecutionException
6768

6869
// add last ai message to agent internal memory
6970
internalChatMemory.add(aiMessage);
70-
updateChatHistory(chatHistory, internalChatMemory);
71+
return buildResponse(chatHistory, internalChatMemory);
7172
} catch (Exception e) {
7273
throw new AgentExecutionException("Error during agent [%s] invocation".formatted(this.getName()), e);
7374
}
7475
}
7576

76-
protected void updateChatHistory(List<ChatMessage> chatHistory, ChatMemory internalChatMemory) {
77-
//delete extenal messages to avoid duplication
78-
chatHistory.clear();
79-
//add previous history + agent internal messages
80-
internalChatMemory.messages()
81-
.stream()
82-
.filter(m -> !(m instanceof SystemMessage))
83-
.forEach(chatHistory::add);
77+
protected List<ChatMessage> buildResponse(List<ChatMessage> chatHistory, ChatMemory internalChatMemory) {
78+
return internalChatMemory.messages()
79+
.stream()
80+
.filter(m -> !(m instanceof SystemMessage))
81+
.collect(Collectors.toList());
8482
}
8583

8684
protected List<ToolExecutionResultMessage> executeToolRequests(List<ToolExecutionRequest> toolExecutionRequests) {

app/copilot/langchain4j-agents/src/main/java/com/microsoft/langchain4j/agent/Agent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ public interface Agent {
88

99
String getName();
1010
AgentMetadata getMetadata();
11-
void invoke(List<ChatMessage> chatHistory) throws AgentExecutionException;
11+
List<ChatMessage> invoke(List<ChatMessage> chatHistory) throws AgentExecutionException;
1212
}

app/copilot/langchain4j-agents/src/main/java/com/microsoft/openai/samples/assistant/langchain4j/agent/SupervisorAgent.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.slf4j.Logger;
1717
import org.slf4j.LoggerFactory;
1818

19+
import java.util.ArrayList;
1920
import java.util.List;
2021
import java.util.Map;
2122
import java.util.stream.Collectors;
@@ -55,7 +56,7 @@ public SupervisorAgent(ChatLanguageModel chatLanguageModel, List<Agent> agents)
5556
}
5657

5758

58-
public void invoke(List<ChatMessage> chatHistory) {
59+
public List<ChatMessage> invoke(List<ChatMessage> chatHistory) {
5960
LOGGER.info("------------- SupervisorAgent -------------");
6061

6162
var internalChatMemory = buildInternalChat(chatHistory);
@@ -69,27 +70,29 @@ public void invoke(List<ChatMessage> chatHistory) {
6970
LOGGER.info("Supervisor Agent handoff to [{}]", nextAgent);
7071

7172
if (routing) {
72-
singleTurnRouting(nextAgent, chatHistory);
73+
return singleTurnRouting(nextAgent, chatHistory);
7374
}
75+
76+
return new ArrayList<>();
7477
}
7578

7679

77-
protected void singleTurnRouting(String nextAgent, List<ChatMessage> chatHistory) {
80+
protected List<ChatMessage> singleTurnRouting(String nextAgent, List<ChatMessage> chatHistory) {
7881
if("none".equalsIgnoreCase(nextAgent)){
7982
LOGGER.info("Gracefully handle clarification.. ");
8083
AiMessage clarificationMessage = AiMessage.builder().
8184
text(" I'm not sure about your request. Can you please clarify?")
8285
.build();
8386
chatHistory.add(clarificationMessage);
84-
return;
87+
return chatHistory;
8588
}
8689

8790
Agent agent = agents.stream()
8891
.filter(a -> a.getName().equals(nextAgent))
8992
.findFirst()
9093
.orElseThrow(() -> new AgentExecutionException("Agent not found: " + nextAgent));
9194

92-
agent.invoke(chatHistory);
95+
return agent.invoke(chatHistory);
9396
}
9497

9598

0 commit comments

Comments
 (0)