Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.agents.chatcompletion;

import com.microsoft.semantickernel.Kernel;
Expand All @@ -6,10 +7,10 @@
import com.microsoft.semantickernel.agents.AgentThread;
import com.microsoft.semantickernel.agents.KernelAgent;
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.semanticfunctions.KernelArguments;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplate;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
Expand Down Expand Up @@ -37,8 +38,7 @@ private ChatCompletionAgent(
KernelArguments kernelArguments,
InvocationContext context,
String instructions,
PromptTemplate template
) {
PromptTemplate template) {
super(
id,
name,
Expand All @@ -47,8 +47,7 @@ private ChatCompletionAgent(
kernelArguments,
context,
instructions,
template
);
template);
}

/**
Expand All @@ -61,70 +60,65 @@ private ChatCompletionAgent(
*/
@Override
public Mono<List<AgentResponseItem<ChatMessageContent<?>>>> invokeAsync(
List<ChatMessageContent<?>> messages,
AgentThread thread,
@Nullable AgentInvokeOptions options
) {
List<ChatMessageContent<?>> messages,
@Nullable AgentThread thread,
@Nullable AgentInvokeOptions options) {
return ensureThreadExistsWithMessagesAsync(messages, thread, ChatHistoryAgentThread::new)
.cast(ChatHistoryAgentThread.class)
.flatMap(agentThread -> {
// Extract the chat history from the thread
ChatHistory history = new ChatHistory(
agentThread.getChatHistory().getMessages()
);

// Invoke the agent with the chat history
return internalInvokeAsync(
history,
options
)
.flatMapMany(Flux::fromIterable)
// notify on the new thread instance
.concatMap(agentMessage -> this.notifyThreadOfNewMessageAsync(agentThread, agentMessage).thenReturn(agentMessage))
.collectList()
.map(chatMessageContents ->
chatMessageContents.stream()
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message, agentThread))
.collect(Collectors.toList())
);
});
.cast(ChatHistoryAgentThread.class)
.flatMap(agentThread -> {
// Extract the chat history from the thread
ChatHistory history = new ChatHistory(
agentThread.getChatHistory().getMessages());

// Invoke the agent with the chat history
return internalInvokeAsync(
history,
agentThread,
options)
.map(chatMessageContents -> chatMessageContents.stream()
.map(message -> new AgentResponseItem<ChatMessageContent<?>>(message,
agentThread))
.collect(Collectors.toList()));
});
}

private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
ChatHistory history,
@Nullable AgentInvokeOptions options
) {
AgentThread thread,
@Nullable AgentInvokeOptions options) {
if (options == null) {
options = new AgentInvokeOptions();
}

final Kernel kernel = options.getKernel() != null ? options.getKernel() : this.kernel;
final KernelArguments arguments = mergeArguments(options.getKernelArguments());
final String additionalInstructions = options.getAdditionalInstructions();
final InvocationContext invocationContext = options.getInvocationContext() != null ? options.getInvocationContext() : this.invocationContext;
final InvocationContext invocationContext = options.getInvocationContext() != null
? options.getInvocationContext()
: this.invocationContext;

try {
ChatCompletionService chatCompletionService = kernel.getService(ChatCompletionService.class, arguments);
ChatCompletionService chatCompletionService = kernel
.getService(ChatCompletionService.class, arguments);

PromptExecutionSettings executionSettings = invocationContext != null && invocationContext.getPromptExecutionSettings() != null
PromptExecutionSettings executionSettings = invocationContext != null
&& invocationContext.getPromptExecutionSettings() != null
? invocationContext.getPromptExecutionSettings()
: kernelArguments.getExecutionSettings().get(chatCompletionService.getServiceId());

ToolCallBehavior toolCallBehavior = invocationContext != null
? invocationContext.getToolCallBehavior()
: ToolCallBehavior.allowAllKernelFunctions(true);
: arguments.getExecutionSettings()
.get(chatCompletionService.getServiceId());

// Build base invocation context
InvocationContext.Builder builder = InvocationContext.builder()
.withPromptExecutionSettings(executionSettings)
.withToolCallBehavior(toolCallBehavior)
.withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY);
.withPromptExecutionSettings(executionSettings)
.withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY);

if (invocationContext != null) {
builder = builder
.withTelemetry(invocationContext.getTelemetry())
.withContextVariableConverter(invocationContext.getContextVariableTypes())
.withKernelHooks(invocationContext.getKernelHooks());
.withTelemetry(invocationContext.getTelemetry())
.withFunctionChoiceBehavior(invocationContext.getFunctionChoiceBehavior())
.withToolCallBehavior(invocationContext.getToolCallBehavior())
.withContextVariableConverter(invocationContext.getContextVariableTypes())
.withKernelHooks(invocationContext.getKernelHooks());
}

InvocationContext agentInvocationContext = builder.build();
Expand All @@ -133,32 +127,65 @@ private Mono<List<ChatMessageContent<?>>> internalInvokeAsync(
instructions -> {
// Create a new chat history with the instructions
ChatHistory chat = new ChatHistory(
instructions
);
instructions);

// Add agent additional instructions
if (additionalInstructions != null) {
chat.addMessage(new ChatMessageContent<>(
AuthorRole.SYSTEM,
additionalInstructions
));
AuthorRole.SYSTEM,
additionalInstructions));
}

// Add the chat history to the new chat
chat.addAll(history);

return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext);
}
);
// Retrieve the chat message contents asynchronously and notify the thread
if (shouldNotifyFunctionCalls(agentInvocationContext)) {
// Notify all messages including function calls
return chatCompletionService
.getChatMessageContentsAsync(chat, kernel, agentInvocationContext)
.flatMapMany(Flux::fromIterable)
.concatMap(message -> notifyThreadOfNewMessageAsync(thread, message)
.thenReturn(message))
// Filter out function calls and their results
.filter(message -> message.getContent() != null
&& message.getAuthorRole() != AuthorRole.TOOL)
.collect(Collectors.toList());
}

// Return chat completion messages without notifying the thread
// We shouldn't add the function call content to the thread, since
// we don't know if the user will execute the call. They should add it themselves.
return chatCompletionService.getChatMessageContentsAsync(chat, kernel,
agentInvocationContext);
});

} catch (ServiceNotFoundException e) {
return Mono.error(e);
}
}

/**
 * Decides whether the messages returned by the chat completion service should be
 * pushed to the agent thread by this agent.
 * <p>
 * The answer is taken from the function-invocation configuration of the given context:
 * an {@link AutoFunctionChoiceBehavior} answers through {@code isAutoInvoke()}; otherwise
 * a configured tool call behavior answers through {@code isAutoInvokeAllowed()}.
 *
 * @param invocationContext The invocation context to inspect; may be {@code null}.
 * @return {@code true} if the applicable behavior reports auto-invocation; {@code false}
 *         when the context is {@code null} or neither behavior applies.
 */
boolean shouldNotifyFunctionCalls(InvocationContext invocationContext) {
    if (invocationContext == null) {
        return false;
    }

    // instanceof evaluates to false for a null operand, so no separate null check is needed.
    if (invocationContext.getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior) {
        AutoFunctionChoiceBehavior autoBehavior = (AutoFunctionChoiceBehavior) invocationContext
            .getFunctionChoiceBehavior();
        return autoBehavior.isAutoInvoke();
    }

    // Fall back to the tool call behavior; an absent behavior means no auto-invocation.
    return invocationContext.getToolCallBehavior() != null
        && invocationContext.getToolCallBehavior().isAutoInvokeAllowed();
}

@Override
public Mono<Void> notifyThreadOfNewMessageAsync(AgentThread thread, ChatMessageContent<?> message) {
public Mono<Void> notifyThreadOfNewMessageAsync(AgentThread thread,
ChatMessageContent<?> message) {
return Mono.defer(() -> {
return thread.onNewMessageAsync(message);
});
Expand Down Expand Up @@ -273,11 +300,10 @@ public ChatCompletionAgent build() {
name,
description,
kernel,
kernelArguments,
kernelArguments,
invocationContext,
instructions,
template
);
template);
}

/**
Expand All @@ -287,17 +313,17 @@ public ChatCompletionAgent build() {
* @param promptTemplateFactory The prompt template factory to use.
* @return The ChatCompletionAgent instance.
*/
public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig, PromptTemplateFactory promptTemplateFactory) {
public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig,
PromptTemplateFactory promptTemplateFactory) {
return new ChatCompletionAgent(
id,
name,
description,
kernel,
kernelArguments,
kernelArguments,
invocationContext,
promptTemplateConfig.getTemplate(),
promptTemplateFactory.tryCreate(promptTemplateConfig)
);
promptTemplateFactory.tryCreate(promptTemplateConfig));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.agents.chatcompletion;

import com.microsoft.semantickernel.agents.AgentThread;
Expand All @@ -16,12 +17,25 @@
public class ChatHistoryAgentThread extends BaseAgentThread {
private ChatHistory chatHistory;

/**
 * Constructor for ChatHistoryAgentThread.
 * <p>
 * Generates a random UUID string for the thread id and passes it, together with a
 * newly created {@code ChatHistory}, to the two-argument constructor.
 */
public ChatHistoryAgentThread() {
this(UUID.randomUUID().toString(), new ChatHistory());
}

/**
 * Constructor for ChatHistoryAgentThread.
 * <p>
 * Generates a random UUID string for the thread id and delegates to the
 * two-argument constructor.
 *
 * @param chatHistory The chat history; may be {@code null}, in which case it is passed
 * through unchanged to the two-argument constructor.
 */
public ChatHistoryAgentThread(@Nullable ChatHistory chatHistory) {
this(UUID.randomUUID().toString(), chatHistory);
}

/**
* Constructor for ChatHistoryAgentThread.
*
* @param id The ID of the thread.
* @param chatHistory The chat history.
Expand Down Expand Up @@ -76,7 +90,6 @@ public List<ChatMessageContent<?>> getMessages() {
return chatHistory.getMessages();
}


/**
 * Creates a new {@code Builder}.
 *
 * @return A new {@code Builder} instance.
 */
public static Builder builder() {
return new Builder();
}
Expand Down
Loading