Skip to content

Commit d39be65

Browse files
author
Milder Hernandez
authored
Merge pull request #310 from milderhc/function-choice
Add FunctionChoiceBehavior for OpenAI
2 parents 7c7c561 + 730af6c commit d39be65

File tree

35 files changed

+1489
-483
lines changed

35 files changed

+1489
-483
lines changed

agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// Copyright (c) Microsoft. All rights reserved.
12
package com.microsoft.semantickernel.agents.chatcompletion;
23

34
import com.microsoft.semantickernel.Kernel;
@@ -6,10 +7,10 @@
67
import com.microsoft.semantickernel.agents.AgentThread;
78
import com.microsoft.semantickernel.agents.KernelAgent;
89
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
10+
import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior;
911
import com.microsoft.semantickernel.orchestration.InvocationContext;
1012
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
1113
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
12-
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
1314
import com.microsoft.semantickernel.semanticfunctions.KernelArguments;
1415
import com.microsoft.semantickernel.semanticfunctions.PromptTemplate;
1516
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
@@ -37,8 +38,7 @@ private ChatCompletionAgent(
3738
KernelArguments kernelArguments,
3839
InvocationContext context,
3940
String instructions,
40-
PromptTemplate template
41-
) {
41+
PromptTemplate template) {
4242
super(
4343
id,
4444
name,
@@ -47,8 +47,7 @@ private ChatCompletionAgent(
4747
kernelArguments,
4848
context,
4949
instructions,
50-
template
51-
);
50+
template);
5251
}
5352

5453
/**
@@ -61,70 +60,65 @@ private ChatCompletionAgent(
6160
*/
6261
@Override
6362
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(
64-
List<ChatMessageContent<?>> messages,
65-
AgentThread thread,
66-
@Nullable AgentInvokeOptions options
67-
) {
63+
List<ChatMessageContent<?>> messages,
64+
@Nullable AgentThread thread,
65+
@Nullable AgentInvokeOptions options) {
6866
return ensureThreadExistsWithMessagesAsync(messages, thread, ChatHistoryAgentThread::new)
69-
.cast(ChatHistoryAgentThread.class)
70-
.flatMap(agentThread -> {
71-
// Extract the chat history from the thread
72-
ChatHistory history = new ChatHistory(
73-
agentThread.getChatHistory().getMessages()
74-
);
75-
76-
// Invoke the agent with the chat history
77-
return internalInvokeAsync(
78-
history,
79-
options
80-
)
81-
.flatMapMany(Flux::fromIterable)
82-
// notify on the new thread instance
83-
.concatMap(agentMessage -> this.notifyThreadOfNewMessageAsync(agentThread, agentMessage).thenReturn(agentMessage))
84-
.collectList()
85-
.map(chatMessageContents ->
86-
chatMessageContents.stream()
87-
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message, agentThread))
88-
.collect(Collectors.toList())
89-
);
90-
});
67+
.cast(ChatHistoryAgentThread.class)
68+
.flatMap(agentThread -> {
69+
// Extract the chat history from the thread
70+
ChatHistory history = new ChatHistory(
71+
agentThread.getChatHistory().getMessages());
72+
73+
// Invoke the agent with the chat history
74+
return internalInvokeAsync(
75+
history,
76+
agentThread,
77+
options)
78+
.map(chatMessageContents -> chatMessageContents.stream()
79+
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message,
80+
agentThread))
81+
.collect(Collectors.toList()));
82+
});
9183
}
9284

9385
private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
9486
ChatHistory history,
95-
@Nullable AgentInvokeOptions options
96-
) {
87+
AgentThread thread,
88+
@Nullable AgentInvokeOptions options) {
9789
if (options == null) {
9890
options = new AgentInvokeOptions();
9991
}
10092

10193
final Kernel kernel = options.getKernel() != null ? options.getKernel() : this.kernel;
10294
final KernelArguments arguments = mergeArguments(options.getKernelArguments());
10395
final String additionalInstructions = options.getAdditionalInstructions();
104-
final InvocationContext invocationContext = options.getInvocationContext() != null ? options.getInvocationContext() : this.invocationContext;
96+
final InvocationContext invocationContext = options.getInvocationContext() != null
97+
? options.getInvocationContext()
98+
: this.invocationContext;
10599

106100
try {
107-
ChatCompletionService chatCompletionService = kernel.getService(ChatCompletionService.class, arguments);
101+
ChatCompletionService chatCompletionService = kernel
102+
.getService(ChatCompletionService.class, arguments);
108103

109-
PromptExecutionSettings executionSettings = invocationContext != null && invocationContext.getPromptExecutionSettings() != null
104+
PromptExecutionSettings executionSettings = invocationContext != null
105+
&& invocationContext.getPromptExecutionSettings() != null
110106
? invocationContext.getPromptExecutionSettings()
111-
: kernelArguments.getExecutionSettings().get(chatCompletionService.getServiceId());
112-
113-
ToolCallBehavior toolCallBehavior = invocationContext != null
114-
? invocationContext.getToolCallBehavior()
115-
: ToolCallBehavior.allowAllKernelFunctions(true);
107+
: arguments.getExecutionSettings()
108+
.get(chatCompletionService.getServiceId());
116109

117110
// Build base invocation context
118111
InvocationContext.Builder builder = InvocationContext.builder()
119-
.withPromptExecutionSettings(executionSettings)
120-
.withToolCallBehavior(toolCallBehavior)
121-
.withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY);
112+
.withPromptExecutionSettings(executionSettings)
113+
.withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY);
122114

123115
if (invocationContext != null) {
124116
builder = builder
125-
.withTelemetry(invocationContext.getTelemetry())
126-
.withContextVariableConverter(invocationContext.getContextVariableTypes())
127-
.withKernelHooks(invocationContext.getKernelHooks());
117+
.withTelemetry(invocationContext.getTelemetry())
118+
.withFunctionChoiceBehavior(invocationContext.getFunctionChoiceBehavior())
119+
.withToolCallBehavior(invocationContext.getToolCallBehavior())
120+
.withContextVariableConverter(invocationContext.getContextVariableTypes())
121+
.withKernelHooks(invocationContext.getKernelHooks());
128122
}
129123

130124
InvocationContext agentInvocationContext = builder.build();
@@ -133,32 +127,65 @@ private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
133127
instructions -> {
134128
// Create a new chat history with the instructions
135129
ChatHistory chat = new ChatHistory(
136-
instructions
137-
);
130+
instructions);
138131

139132
// Add agent additional instructions
140133
if (additionalInstructions != null) {
141134
chat.addMessage(new ChatMessageContent<>(
142-
AuthorRole.SYSTEM,
143-
additionalInstructions
144-
));
135+
AuthorRole.SYSTEM,
136+
additionalInstructions));
145137
}
146138

147139
// Add the chat history to the new chat
148140
chat.addAll(history);
149141

150-
return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext);
151-
}
152-
);
142+
// Retrieve the chat message contents asynchronously and notify the thread
143+
if (shouldNotifyFunctionCalls(agentInvocationContext)) {
144+
// Notify all messages including function calls
145+
return chatCompletionService
146+
.getChatMessageContentsAsync(chat, kernel, agentInvocationContext)
147+
.flatMapMany(Flux::fromIterable)
148+
.concatMap(message -> notifyThreadOfNewMessageAsync(thread, message)
149+
.thenReturn(message))
150+
// Filter out function calls and their results
151+
.filter(message -> message.getContent() != null
152+
&& message.getAuthorRole() != AuthorRole.TOOL)
153+
.collect(Collectors.toList());
154+
}
155+
156+
// Return chat completion messages without notifying the thread
157+
// We shouldn't add the function call content to the thread, since
158+
// we don't know if the user will execute the call. They should add it themselves.
159+
return chatCompletionService.getChatMessageContentsAsync(chat, kernel,
160+
agentInvocationContext);
161+
});
153162

154163
} catch (ServiceNotFoundException e) {
155164
return Mono.error(e);
156165
}
157166
}
158167

168+
boolean shouldNotifyFunctionCalls(InvocationContext invocationContext) {
169+
if (invocationContext == null) {
170+
return false;
171+
}
172+
173+
if (invocationContext.getFunctionChoiceBehavior() != null && invocationContext
174+
.getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior) {
175+
return ((AutoFunctionChoiceBehavior) invocationContext.getFunctionChoiceBehavior())
176+
.isAutoInvoke();
177+
}
178+
179+
if (invocationContext.getToolCallBehavior() != null) {
180+
return invocationContext.getToolCallBehavior().isAutoInvokeAllowed();
181+
}
182+
183+
return false;
184+
}
159185

160186
@Override
161-
public Mono<Void> notifyThreadOfNewMessageAsync(AgentThread thread, ChatMessageContent<?> message) {
187+
public Mono<Void> notifyThreadOfNewMessageAsync(AgentThread thread,
188+
ChatMessageContent<?> message) {
162189
return Mono.defer(() -> {
163190
return thread.onNewMessageAsync(message);
164191
});
@@ -273,11 +300,10 @@ public ChatCompletionAgent build() {
273300
name,
274301
description,
275302
kernel,
276-
kernelArguments,
303+
kernelArguments,
277304
invocationContext,
278305
instructions,
279-
template
280-
);
306+
template);
281307
}
282308

283309
/**
@@ -287,17 +313,17 @@ public ChatCompletionAgent build() {
287313
* @param promptTemplateFactory The prompt template factory to use.
288314
* @return The ChatCompletionAgent instance.
289315
*/
290-
public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig, PromptTemplateFactory promptTemplateFactory) {
316+
public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig,
317+
PromptTemplateFactory promptTemplateFactory) {
291318
return new ChatCompletionAgent(
292319
id,
293320
name,
294321
description,
295322
kernel,
296-
kernelArguments,
323+
kernelArguments,
297324
invocationContext,
298325
promptTemplateConfig.getTemplate(),
299-
promptTemplateFactory.tryCreate(promptTemplateConfig)
300-
);
326+
promptTemplateFactory.tryCreate(promptTemplateConfig));
301327
}
302328
}
303329
}

agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// Copyright (c) Microsoft. All rights reserved.
12
package com.microsoft.semantickernel.agents.chatcompletion;
23

34
import com.microsoft.semantickernel.agents.AgentThread;
@@ -16,12 +17,25 @@
1617
public class ChatHistoryAgentThread extends BaseAgentThread {
1718
private ChatHistory chatHistory;
1819

20+
/**
21+
* Creates a ChatHistoryAgentThread with a randomly generated UUID as its ID and a new ChatHistory.
22+
*
23+
*/
1924
public ChatHistoryAgentThread() {
2025
this(UUID.randomUUID().toString(), new ChatHistory());
2126
}
2227

2328
/**
24-
* Constructor for com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread.
29+
* Creates a ChatHistoryAgentThread with a randomly generated UUID as its ID and the given chat history.
30+
*
31+
* @param chatHistory The chat history.
32+
*/
33+
public ChatHistoryAgentThread(@Nullable ChatHistory chatHistory) {
34+
this(UUID.randomUUID().toString(), chatHistory);
35+
}
36+
37+
/**
38+
* Constructor for ChatHistoryAgentThread.
2539
*
2640
* @param id The ID of the thread.
2741
* @param chatHistory The chat history.
@@ -76,7 +90,6 @@ public List<ChatMessageContent<?>> getMessages() {
7690
return chatHistory.getMessages();
7791
}
7892

79-
8093
public static Builder builder() {
8194
return new Builder();
8295
}

0 commit comments

Comments
 (0)