diff --git a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java index f8423f3d..b5294fe6 100644 --- a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java +++ b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatCompletionAgent.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents.chatcompletion; import com.microsoft.semantickernel.Kernel; @@ -6,10 +7,10 @@ import com.microsoft.semantickernel.agents.AgentThread; import com.microsoft.semantickernel.agents.KernelAgent; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.InvocationReturnMode; import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; -import com.microsoft.semantickernel.orchestration.ToolCallBehavior; import com.microsoft.semantickernel.semanticfunctions.KernelArguments; import com.microsoft.semantickernel.semanticfunctions.PromptTemplate; import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig; @@ -37,8 +38,7 @@ private ChatCompletionAgent( KernelArguments kernelArguments, InvocationContext context, String instructions, - PromptTemplate template - ) { + PromptTemplate template) { super( id, name, @@ -47,8 +47,7 @@ private ChatCompletionAgent( kernelArguments, context, instructions, - template - ); + template); } /** @@ -61,39 +60,32 @@ private ChatCompletionAgent( */ @Override public Mono>>> invokeAsync( - List> messages, - AgentThread thread, - @Nullable AgentInvokeOptions options - ) { + List> messages, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options) { return ensureThreadExistsWithMessagesAsync(messages, thread, ChatHistoryAgentThread::new) - .cast(ChatHistoryAgentThread.class) - .flatMap(agentThread -> { - // Extract the chat history from the thread - ChatHistory history = new ChatHistory( - agentThread.getChatHistory().getMessages() - ); - - // Invoke the agent with the chat history - return internalInvokeAsync( - history, - options - ) - .flatMapMany(Flux::fromIterable) - // notify on the new thread instance - .concatMap(agentMessage -> this.notifyThreadOfNewMessageAsync(agentThread, agentMessage).thenReturn(agentMessage)) - .collectList() - .map(chatMessageContents -> - chatMessageContents.stream() - .map(message -> new AgentResponseItem>(message, agentThread)) - .collect(Collectors.toList()) - ); - }); + .cast(ChatHistoryAgentThread.class) + .flatMap(agentThread -> { + // Extract the chat history from the thread + ChatHistory history = new ChatHistory( + agentThread.getChatHistory().getMessages()); + + // Invoke the agent with the chat history + return internalInvokeAsync( + history, + agentThread, + options) + .map(chatMessageContents -> chatMessageContents.stream() + .map(message -> new AgentResponseItem>(message, + agentThread)) + .collect(Collectors.toList())); + }); } private Mono>> internalInvokeAsync( ChatHistory history, - @Nullable AgentInvokeOptions options - ) { + AgentThread thread, + @Nullable AgentInvokeOptions options) { if (options == null) { options = new AgentInvokeOptions(); } @@ -101,30 +93,32 @@ private Mono>> internalInvokeAsync( final Kernel kernel = options.getKernel() != null ? options.getKernel() : this.kernel; final KernelArguments arguments = mergeArguments(options.getKernelArguments()); final String additionalInstructions = options.getAdditionalInstructions(); - final InvocationContext invocationContext = options.getInvocationContext() != null ? options.getInvocationContext() : this.invocationContext; + final InvocationContext invocationContext = options.getInvocationContext() != null + ? options.getInvocationContext() + : this.invocationContext; try { - ChatCompletionService chatCompletionService = kernel.getService(ChatCompletionService.class, arguments); + ChatCompletionService chatCompletionService = kernel + .getService(ChatCompletionService.class, arguments); - PromptExecutionSettings executionSettings = invocationContext != null && invocationContext.getPromptExecutionSettings() != null + PromptExecutionSettings executionSettings = invocationContext != null + && invocationContext.getPromptExecutionSettings() != null ? invocationContext.getPromptExecutionSettings() - : kernelArguments.getExecutionSettings().get(chatCompletionService.getServiceId()); - - ToolCallBehavior toolCallBehavior = invocationContext != null - ? invocationContext.getToolCallBehavior() - : ToolCallBehavior.allowAllKernelFunctions(true); + : arguments.getExecutionSettings() + .get(chatCompletionService.getServiceId()); // Build base invocation context InvocationContext.Builder builder = InvocationContext.builder() - .withPromptExecutionSettings(executionSettings) - .withToolCallBehavior(toolCallBehavior) - .withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY); + .withPromptExecutionSettings(executionSettings) + .withReturnMode(InvocationReturnMode.NEW_MESSAGES_ONLY); if (invocationContext != null) { builder = builder - .withTelemetry(invocationContext.getTelemetry()) - .withContextVariableConverter(invocationContext.getContextVariableTypes()) - .withKernelHooks(invocationContext.getKernelHooks()); + .withTelemetry(invocationContext.getTelemetry()) + .withFunctionChoiceBehavior(invocationContext.getFunctionChoiceBehavior()) + .withToolCallBehavior(invocationContext.getToolCallBehavior()) + .withContextVariableConverter(invocationContext.getContextVariableTypes()) + .withKernelHooks(invocationContext.getKernelHooks()); } InvocationContext agentInvocationContext = builder.build(); @@ -133,32 +127,65 @@ private Mono>> internalInvokeAsync( instructions -> { // Create a new chat history with the instructions ChatHistory chat = new ChatHistory( - instructions - ); + instructions); // Add agent additional instructions if (additionalInstructions != null) { chat.addMessage(new ChatMessageContent<>( - AuthorRole.SYSTEM, - additionalInstructions - )); + AuthorRole.SYSTEM, + additionalInstructions)); } // Add the chat history to the new chat chat.addAll(history); - return chatCompletionService.getChatMessageContentsAsync(chat, kernel, agentInvocationContext); - } - ); + // Retrieve the chat message contents asynchronously and notify the thread + if (shouldNotifyFunctionCalls(agentInvocationContext)) { + // Notify all messages including function calls + return chatCompletionService + .getChatMessageContentsAsync(chat, kernel, agentInvocationContext) + .flatMapMany(Flux::fromIterable) + .concatMap(message -> notifyThreadOfNewMessageAsync(thread, message) + .thenReturn(message)) + // Filter out function calls and their results + .filter(message -> message.getContent() != null + && message.getAuthorRole() != AuthorRole.TOOL) + .collect(Collectors.toList()); + } + + // Return chat completion messages without notifying the thread + // We shouldn't add the function call content to the thread, since + // we don't know if the user will execute the call. They should add it themselves. + return chatCompletionService.getChatMessageContentsAsync(chat, kernel, + agentInvocationContext); + }); } catch (ServiceNotFoundException e) { return Mono.error(e); } } + boolean shouldNotifyFunctionCalls(InvocationContext invocationContext) { + if (invocationContext == null) { + return false; + } + + if (invocationContext.getFunctionChoiceBehavior() != null && invocationContext + .getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior) { + return ((AutoFunctionChoiceBehavior) invocationContext.getFunctionChoiceBehavior()) + .isAutoInvoke(); + } + + if (invocationContext.getToolCallBehavior() != null) { + return invocationContext.getToolCallBehavior().isAutoInvokeAllowed(); + } + + return false; + } @Override - public Mono notifyThreadOfNewMessageAsync(AgentThread thread, ChatMessageContent message) { + public Mono notifyThreadOfNewMessageAsync(AgentThread thread, + ChatMessageContent message) { return Mono.defer(() -> { return thread.onNewMessageAsync(message); }); @@ -273,11 +300,10 @@ public ChatCompletionAgent build() { name, description, kernel, - kernelArguments, + kernelArguments, invocationContext, instructions, - template - ); + template); } /** @@ -287,17 +313,17 @@ public ChatCompletionAgent build() { * @param promptTemplateFactory The prompt template factory to use. * @return The ChatCompletionAgent instance. */ - public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig, PromptTemplateFactory promptTemplateFactory) { + public ChatCompletionAgent build(PromptTemplateConfig promptTemplateConfig, + PromptTemplateFactory promptTemplateFactory) { return new ChatCompletionAgent( id, name, description, kernel, - kernelArguments, + kernelArguments, invocationContext, promptTemplateConfig.getTemplate(), - promptTemplateFactory.tryCreate(promptTemplateConfig) - ); + promptTemplateFactory.tryCreate(promptTemplateConfig)); } } } diff --git a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java index 1a68f8c4..6b3f62a9 100644 --- a/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java +++ b/agents/semantickernel-agents-core/src/main/java/com/microsoft/semantickernel/agents/chatcompletion/ChatHistoryAgentThread.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents.chatcompletion; import com.microsoft.semantickernel.agents.AgentThread; @@ -16,12 +17,25 @@ public class ChatHistoryAgentThread extends BaseAgentThread { private ChatHistory chatHistory; + /** + * Constructor for ChatHistoryAgentThread. + * + */ public ChatHistoryAgentThread() { this(UUID.randomUUID().toString(), new ChatHistory()); } /** - * Constructor for com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread. + * Constructor for ChatHistoryAgentThread. + * + * @param chatHistory The chat history. + */ + public ChatHistoryAgentThread(@Nullable ChatHistory chatHistory) { + this(UUID.randomUUID().toString(), chatHistory); + } + + /** + * Constructor for ChatHistoryAgentThread. * * @param id The ID of the thread. * @param chatHistory The chat history. @@ -76,7 +90,6 @@ public List> getMessages() { return chatHistory.getMessages(); } - public static Builder builder() { return new Builder(); } diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 33f46fc2..1db5bd38 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -44,6 +44,10 @@ import com.microsoft.semantickernel.exceptions.AIException.ErrorCodes; import com.microsoft.semantickernel.exceptions.SKCheckedException; import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.functionchoice.AutoFunctionChoiceBehavior; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; +import com.microsoft.semantickernel.functionchoice.NoneFunctionChoiceBehavior; +import com.microsoft.semantickernel.functionchoice.RequiredFunctionChoiceBehavior; import com.microsoft.semantickernel.hooks.KernelHookEvent; import com.microsoft.semantickernel.hooks.KernelHooks; import com.microsoft.semantickernel.hooks.PostChatCompletionEvent; @@ -196,10 +200,20 @@ public Flux> getStreamingChatMessageContentsAsync( ChatHistory chatHistory, @Nullable Kernel kernel, @Nullable InvocationContext invocationContext) { - if (invocationContext != null && invocationContext.getToolCallBehavior() - .isAutoInvokeAllowed()) { + if (invocationContext != null && + invocationContext.getToolCallBehavior() != null && + invocationContext.getToolCallBehavior().isAutoInvokeAllowed()) { throw new SKException( - "Auto invoke is not supported for streaming chat message contents"); + "ToolCallBehavior auto-invoke is not supported for streaming chat message contents"); + } + + if (invocationContext != null && + invocationContext.getFunctionChoiceBehavior() != null && + invocationContext.getFunctionChoiceBehavior() instanceof AutoFunctionChoiceBehavior && + ((AutoFunctionChoiceBehavior) invocationContext.getFunctionChoiceBehavior()) + .isAutoInvoke()) { + throw new SKException( + "FunctionChoiceBehavior auto-invoke is not supported for streaming chat message contents"); } if (invocationContext != null @@ -219,6 +233,12 @@ public Flux> getStreamingChatMessageContentsAsync( .add(OpenAIFunction.build(function.getMetadata(), plugin.getName())))); } + OpenAIToolCallConfig toolCallConfig = getToolCallConfig( + invocationContext, + functions, + messages.allMessages, + 0); + ChatCompletionsOptions options = executeHook( invocationContext, kernel, @@ -226,8 +246,8 @@ public Flux> getStreamingChatMessageContentsAsync( getCompletionsOptions( this, messages.allMessages, - functions, - invocationContext))) + invocationContext, + toolCallConfig))) .getOptions(); return getClient() @@ -389,16 +409,12 @@ private Mono internalChatMessageContentsAsync( .add(OpenAIFunction.build(function.getMetadata(), plugin.getName())))); } - // Create copy to avoid reactor exceptions when updating request messages internally return internalChatMessageContentsAsync( messages, kernel, functions, invocationContext, - Math.min(MAXIMUM_INFLIGHT_AUTO_INVOKES, - invocationContext != null && invocationContext.getToolCallBehavior() != null - ? invocationContext.getToolCallBehavior().getMaximumAutoInvokeAttempts() - : 0)); + 0); } private Mono internalChatMessageContentsAsync( @@ -406,7 +422,13 @@ private Mono internalChatMessageContentsAsync( @Nullable Kernel kernel, List functions, @Nullable InvocationContext invocationContext, - int autoInvokeAttempts) { + int requestIndex) { + + OpenAIToolCallConfig toolCallConfig = getToolCallConfig( + invocationContext, + functions, + messages.allMessages, + requestIndex); ChatCompletionsOptions options = executeHook( invocationContext, @@ -415,8 +437,8 @@ private Mono internalChatMessageContentsAsync( getCompletionsOptions( this, messages.allMessages, - functions, - invocationContext))) + invocationContext, + toolCallConfig))) .getOptions(); return Mono.deferContextual(contextView -> { @@ -458,9 +480,10 @@ private Mono internalChatMessageContentsAsync( executeHook(invocationContext, kernel, new PostChatCompletionEvent(completions)); // Just return the result: - // If we don't want to attempt to invoke any functions + // If auto-invoking is not enabled // Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested - if (autoInvokeAttempts == 0 || responseMessages.size() != 1) { + if (toolCallConfig == null || !toolCallConfig.isAutoInvoke() + || responseMessages.size() != 1) { List> chatMessageContents = getChatMessageContentsAsync( completions); return Mono.just(messages.addChatMessage(chatMessageContents)); @@ -497,14 +520,14 @@ private Mono internalChatMessageContentsAsync( .flatMap(it -> it) .flatMap(msgs -> { return internalChatMessageContentsAsync(msgs, kernel, functions, - invocationContext, autoInvokeAttempts - 1); + invocationContext, requestIndex + 1); }) .onErrorResume(e -> { LOGGER.warn("Tool invocation attempt failed: ", e); // If FunctionInvocationError occurred and there are still attempts left, retry, else exit - if (autoInvokeAttempts > 0) { + if (requestIndex < MAXIMUM_INFLIGHT_AUTO_INVOKES) { ChatMessages currentMessages = messages; if (e instanceof FunctionInvocationError) { currentMessages.assertCommonHistory( @@ -518,7 +541,7 @@ private Mono internalChatMessageContentsAsync( kernel, functions, invocationContext, - autoInvokeAttempts - 1); + requestIndex + 1); } else { return Mono.error(e); } @@ -860,8 +883,8 @@ private List formOpenAiToolCalls( private static ChatCompletionsOptions getCompletionsOptions( ChatCompletionService chatCompletionService, List chatRequestMessages, - @Nullable List functions, - @Nullable InvocationContext invocationContext) { + @Nullable InvocationContext invocationContext, + @Nullable OpenAIToolCallConfig toolCallConfig) { chatRequestMessages = chatRequestMessages .stream() @@ -871,12 +894,13 @@ private static ChatCompletionsOptions getCompletionsOptions( ChatCompletionsOptions options = new ChatCompletionsOptions(chatRequestMessages) .setModel(chatCompletionService.getModelId()); - if (invocationContext != null && invocationContext.getToolCallBehavior() != null) { - configureToolCallBehaviorOptions( - options, - invocationContext.getToolCallBehavior(), - functions, - chatRequestMessages); + if (toolCallConfig != null) { + options.setTools(toolCallConfig.getTools()); + options.setToolChoice(toolCallConfig.getToolChoice()); + + if (toolCallConfig.getOptions() != null) { + options.setParallelToolCalls(toolCallConfig.getOptions().isParallelCallsAllowed()); + } } PromptExecutionSettings promptExecutionSettings = invocationContext != null @@ -946,92 +970,184 @@ private static ChatCompletionsOptions getCompletionsOptions( return options; } - private static void configureToolCallBehaviorOptions( - ChatCompletionsOptions options, + @Nullable + private static OpenAIToolCallConfig getToolCallConfig( + @Nullable InvocationContext invocationContext, + @Nullable List functions, + List chatRequestMessages, + int requestIndex) { + + if (invocationContext == null || functions == null || functions.isEmpty()) { + return null; + } + + if (invocationContext.getFunctionChoiceBehavior() == null + && invocationContext.getToolCallBehavior() == null) { + return null; + } + + if (invocationContext.getFunctionChoiceBehavior() != null) { + return getFunctionChoiceBehaviorConfig( + invocationContext.getFunctionChoiceBehavior(), + functions, + requestIndex); + } else { + return getToolCallBehaviorConfig( + invocationContext.getToolCallBehavior(), + functions, + chatRequestMessages, + requestIndex); + } + } + + @Nullable + private static OpenAIToolCallConfig getFunctionChoiceBehaviorConfig( + @Nullable FunctionChoiceBehavior functionChoiceBehavior, + @Nullable List functions, + int requestIndex) { + if (functionChoiceBehavior == null) { + return null; + } + + if (functions == null || functions.isEmpty()) { + return null; + } + + ChatCompletionsToolSelection toolChoice; + boolean autoInvoke; + + if (functionChoiceBehavior instanceof RequiredFunctionChoiceBehavior) { + // After first request a required function must have been called already + if (requestIndex >= 1) { + return null; + } + + toolChoice = new ChatCompletionsToolSelection( + ChatCompletionsToolSelectionPreset.REQUIRED); + autoInvoke = ((RequiredFunctionChoiceBehavior) functionChoiceBehavior).isAutoInvoke(); + } else if (functionChoiceBehavior instanceof AutoFunctionChoiceBehavior) { + toolChoice = new ChatCompletionsToolSelection(ChatCompletionsToolSelectionPreset.AUTO); + autoInvoke = ((AutoFunctionChoiceBehavior) functionChoiceBehavior).isAutoInvoke() + && requestIndex < MAXIMUM_INFLIGHT_AUTO_INVOKES; + } else if (functionChoiceBehavior instanceof NoneFunctionChoiceBehavior) { + toolChoice = new ChatCompletionsToolSelection(ChatCompletionsToolSelectionPreset.NONE); + autoInvoke = false; + } else { + throw new SKException( + "Unsupported function choice behavior: " + functionChoiceBehavior); + } + + // List of functions advertised to the model + List toolDefinitions = functions.stream() + .filter(function -> functionChoiceBehavior.isFunctionAllowed(function.getPluginName(), + function.getName())) + .map(OpenAIFunction::getFunctionDefinition) + .map(it -> new ChatCompletionsFunctionToolDefinitionFunction(it.getName()) + .setDescription(it.getDescription()) + .setParameters(it.getParameters())) + .map(ChatCompletionsFunctionToolDefinition::new) + .collect(Collectors.toList()); + + return new OpenAIToolCallConfig( + toolDefinitions, + toolChoice, + autoInvoke, + functionChoiceBehavior.getOptions()); + } + + @Nullable + private static OpenAIToolCallConfig getToolCallBehaviorConfig( @Nullable ToolCallBehavior toolCallBehavior, @Nullable List functions, - List chatRequestMessages) { + List chatRequestMessages, + int requestIndex) { if (toolCallBehavior == null) { - return; + return null; } if (functions == null || functions.isEmpty()) { - return; + return null; } + List toolDefinitions; + ChatCompletionsToolSelection toolChoice; + // If a specific function is required to be called if (toolCallBehavior instanceof ToolCallBehavior.RequiredKernelFunction) { - KernelFunction toolChoice = ((ToolCallBehavior.RequiredKernelFunction) toolCallBehavior) + KernelFunction requiredFunction = ((ToolCallBehavior.RequiredKernelFunction) toolCallBehavior) .getRequiredFunction(); String toolChoiceName = String.format("%s%s%s", - toolChoice.getPluginName(), + requiredFunction.getPluginName(), OpenAIFunction.getNameSeparator(), - toolChoice.getName()); + requiredFunction.getName()); // If required tool call has already been called dont ask for it again boolean hasBeenExecuted = hasToolCallBeenExecuted(chatRequestMessages, toolChoiceName); if (hasBeenExecuted) { - return; + return null; } - List toolDefinitions = new ArrayList<>(); - FunctionDefinition function = OpenAIFunction.toFunctionDefinition( - toolChoice.getMetadata(), - toolChoice.getPluginName()); + requiredFunction.getMetadata(), + requiredFunction.getPluginName()); + toolDefinitions = new ArrayList<>(); toolDefinitions.add(new ChatCompletionsFunctionToolDefinition( new ChatCompletionsFunctionToolDefinitionFunction(function.getName()) .setDescription(function.getDescription()) .setParameters(function.getParameters()))); - options.setTools(toolDefinitions); try { String json = String.format( "{\"type\":\"function\",\"function\":{\"name\":\"%s\"}}", toolChoiceName); - options.setToolChoice( - new ChatCompletionsToolSelection( - ChatCompletionsNamedToolSelection.fromJson( - DefaultJsonReader.fromString( - json, - new JsonOptions())))); + toolChoice = new ChatCompletionsToolSelection( + ChatCompletionsNamedToolSelection.fromJson( + DefaultJsonReader.fromString( + json, + new JsonOptions()))); } catch (JsonProcessingException e) { throw SKException.build("Failed to parse tool choice", e); } catch (IOException e) { throw new SKException(e); } - return; } - // If a set of functions are enabled to be called - ToolCallBehavior.AllowedKernelFunctions enabledKernelFunctions = (ToolCallBehavior.AllowedKernelFunctions) toolCallBehavior; - List toolDefinitions = functions.stream() - .filter(function -> { - // check if all kernel functions are enabled - if (enabledKernelFunctions.isAllKernelFunctionsAllowed()) { - return true; - } - // otherwise, check for the specific function - return enabledKernelFunctions.isFunctionAllowed(function.getPluginName(), - function.getName()); - }) - .map(OpenAIFunction::getFunctionDefinition) - .map(it -> new ChatCompletionsFunctionToolDefinitionFunction(it.getName()) - .setDescription(it.getDescription()) - .setParameters(it.getParameters())) - .map(it -> new ChatCompletionsFunctionToolDefinition(it)) - .collect(Collectors.toList()); + else { + toolChoice = new ChatCompletionsToolSelection(ChatCompletionsToolSelectionPreset.AUTO); + + ToolCallBehavior.AllowedKernelFunctions enabledKernelFunctions = (ToolCallBehavior.AllowedKernelFunctions) toolCallBehavior; + toolDefinitions = functions.stream() + .filter(function -> { + // check if all kernel functions are enabled + if (enabledKernelFunctions.isAllKernelFunctionsAllowed()) { + return true; + } + // otherwise, check for the specific function + return enabledKernelFunctions.isFunctionAllowed(function.getPluginName(), + function.getName()); + }) + .map(OpenAIFunction::getFunctionDefinition) + .map(it -> new ChatCompletionsFunctionToolDefinitionFunction(it.getName()) + .setDescription(it.getDescription()) + .setParameters(it.getParameters())) + .map(ChatCompletionsFunctionToolDefinition::new) + .collect(Collectors.toList()); - if (toolDefinitions.isEmpty()) { - return; + if (toolDefinitions.isEmpty()) { + return null; + } } - options.setTools(toolDefinitions); - options.setToolChoice( - new ChatCompletionsToolSelection(ChatCompletionsToolSelectionPreset.AUTO)); + return new OpenAIToolCallConfig( + toolDefinitions, + toolChoice, + toolCallBehavior.isAutoInvokeAllowed() + && requestIndex < Math.min(MAXIMUM_INFLIGHT_AUTO_INVOKES, + toolCallBehavior.getMaximumAutoInvokeAttempts()), + null); } private static boolean hasToolCallBeenExecuted(List chatRequestMessages, diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java index e1f2f249..cf126d09 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIFunction.java @@ -164,7 +164,7 @@ private static String getSchemaForFunctionParameter(@Nullable InputVariable para entries.add("\"type\":\"" + type + "\""); // Add description if present - String description =null; + String description = null; if (parameter != null && parameter.getDescription() != null && !parameter.getDescription() .isEmpty()) { description = parameter.getDescription(); @@ -173,7 +173,7 @@ private static String getSchemaForFunctionParameter(@Nullable InputVariable para entries.add(String.format("\"description\":\"%s\"", description)); } // If custom type, generate schema - if("object".equalsIgnoreCase(type)) { + if ("object".equalsIgnoreCase(type)) { return getObjectSchema(parameter.getType(), description); } @@ -228,17 +228,17 @@ private static String getJavaTypeToOpenAiFunctionType(String javaType) { } } - private static String getObjectSchema(String type, String description){ - String schema= "{ \"type\" : \"object\" }"; + private static String getObjectSchema(String type, String description) { + String schema = "{ \"type\" : \"object\" }"; try { - Class clazz = Class.forName(type); - schema = ResponseSchemaGenerator.jacksonGenerator().generateSchema(clazz); + Class clazz = Class.forName(type); + schema = ResponseSchemaGenerator.jacksonGenerator().generateSchema(clazz); } catch (ClassNotFoundException | SKException ignored) { } Map properties = BinaryData.fromString(schema).toObject(Map.class); - if(StringUtils.isNotBlank(description)) { + if (StringUtils.isNotBlank(description)) { properties.put("description", description); } return BinaryData.fromObject(properties).toString(); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIToolCallConfig.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIToolCallConfig.java new file mode 100644 index 00000000..454ed3ce --- /dev/null +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIToolCallConfig.java @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.aiservices.openai.chatcompletion; + +import com.azure.ai.openai.models.ChatCompletionsToolDefinition; +import com.azure.ai.openai.models.ChatCompletionsToolSelection; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehaviorOptions; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.List; + +public class OpenAIToolCallConfig { + private final List tools; + private final ChatCompletionsToolSelection toolChoice; + private final boolean autoInvoke; + @Nullable + private final FunctionChoiceBehaviorOptions options; + + /** + * Creates a new instance of the {@link OpenAIToolCallConfig} class. + * + * @param tools The list of tools available for the call. + * @param toolChoice The tool selection strategy. + * @param autoInvoke Indicates whether to automatically invoke the tool. + * @param options Additional options for function choice behavior. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") + public OpenAIToolCallConfig( + List tools, + ChatCompletionsToolSelection toolChoice, + boolean autoInvoke, + @Nullable FunctionChoiceBehaviorOptions options) { + this.tools = tools; + this.toolChoice = toolChoice; + this.autoInvoke = autoInvoke; + this.options = options; + } + + /** + * Gets the list of tools available for the call. + * + * @return The list of tools. + */ + public List getTools() { + return Collections.unmodifiableList(tools); + } + + /** + * Gets the tool selection strategy. + * + * @return The tool selection strategy. + */ + public ChatCompletionsToolSelection getToolChoice() { + return toolChoice; + } + + /** + * Indicates whether to automatically invoke the tool. + * + * @return True if auto-invocation is enabled; otherwise, false. + */ + public boolean isAutoInvoke() { + return autoInvoke; + } + + /** + * Gets additional options for function choice behavior. + * + * @return The function choice behavior options. + */ + public FunctionChoiceBehaviorOptions getOptions() { + return options; + } +} diff --git a/aiservices/openai/src/test/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/JsonSchemaTest.java b/aiservices/openai/src/test/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/JsonSchemaTest.java index 68c8dae8..33870fba 100644 --- a/aiservices/openai/src/test/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/JsonSchemaTest.java +++ b/aiservices/openai/src/test/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/JsonSchemaTest.java @@ -47,12 +47,12 @@ public void openAIFunctionTest() { testFunction.getMetadata(), plugin.getName()); - String parameters = "{\"type\":\"object\",\"required\":[\"person\",\"input\"],\"properties\":{\"input\":{\"type\":\"string\",\"description\":\"input string\"},\"person\":{\"type\":\"object\",\"properties\":{\"age\":{\"type\":\"integer\",\"description\":\"The age of the person.\"},\"name\":{\"type\":\"string\",\"description\":\"The name of the person.\"},\"title\":{\"type\":\"string\",\"enum\":[\"MS\",\"MRS\",\"MR\"],\"description\":\"The title of the person.\"}},\"required\":[\"age\",\"name\",\"title\"],\"additionalProperties\":false,\"description\":\"input person\"}}}"; - Assertions.assertEquals(parameters, openAIFunction.getFunctionDefinition().getParameters().toString()); + String parameters = "{\"type\":\"object\",\"required\":[\"person\",\"input\"],\"properties\":{\"input\":{\"type\":\"string\",\"description\":\"input string\"},\"person\":{\"type\":\"object\",\"properties\":{\"age\":{\"type\":\"integer\",\"description\":\"The age of the person.\"},\"name\":{\"type\":\"string\",\"description\":\"The name of the person.\"},\"title\":{\"type\":\"string\",\"enum\":[\"MS\",\"MRS\",\"MR\"],\"description\":\"The title of the person.\"}},\"required\":[\"age\",\"name\",\"title\"],\"additionalProperties\":false,\"description\":\"input person\"}}}"; + Assertions.assertEquals(parameters, + openAIFunction.getFunctionDefinition().getParameters().toString()); } - public static class TestPlugin { @DefineKernelFunction @@ -67,19 +67,16 @@ public Mono asyncTestFunction( return Mono.just(1); } - @DefineKernelFunction(returnType = "int", description = "test function description", - name = "asyncPersonFunction", returnDescription = "test return description") + @DefineKernelFunction(returnType = "int", description = "test function description", name = "asyncPersonFunction", returnDescription = "test return description") public Mono asyncPersonFunction( - @KernelFunctionParameter(name = "person",description = "input person", type = Person.class) Person person, + @KernelFunctionParameter(name = "person", description = "input person", type = Person.class) Person person, @KernelFunctionParameter(name = "input", description = "input string") String input) { return Mono.just(1); } } private static enum Title { - MS, - MRS, - MR + MS, MRS, MR } public static class Person { @@ -90,7 +87,6 @@ public static class Person { @JsonPropertyDescription("The title of the person.") private Title title; - public Person(String name, int age) { this.name = name; this.age = age; diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightModelTypeConverter.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightModelTypeConverter.java index 250de12e..2752eb62 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightModelTypeConverter.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightModelTypeConverter.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.samples.demos.lights; import com.google.gson.Gson; @@ -10,14 +11,13 @@ public LightModelTypeConverter() { super( LightModel.class, obj -> { - if(obj instanceof String) { - return gson.fromJson((String)obj, LightModel.class); + if (obj instanceof String) { + return gson.fromJson((String) obj, LightModel.class); } else { return gson.fromJson(gson.toJson(obj), LightModel.class); } }, (types, lightModel) -> gson.toJson(lightModel), - json -> gson.fromJson(json, LightModel.class) - ); + json -> gson.fromJson(json, LightModel.class)); } } diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightsPlugin.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightsPlugin.java index fa11addb..398a8d16 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightsPlugin.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/LightsPlugin.java @@ -27,7 +27,7 @@ public List getLights() { @DefineKernelFunction(name = "add_light", description = "Adds a new light") public String addLight( @KernelFunctionParameter(name = "newLight", description = "new Light Details", type = LightModel.class) LightModel light) { - if( light != null) { + if (light != null) { System.out.println("Adding light " + light.getName()); lights.add(light); return "Light added"; diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java index 180ec8ed..0f8065ee 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.samples.plugins.github; import com.fasterxml.jackson.annotation.JsonCreator; @@ -8,7 +9,7 @@ public abstract class GitHubModel { public final static ObjectMapper objectMapper = new ObjectMapper() - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @Override public String toString() { @@ -30,12 +31,13 @@ public static class User extends GitHubModel { private String company; @JsonProperty("html_url") private String url; + @JsonCreator public User(@JsonProperty("login") String login, - @JsonProperty("id") long id, - @JsonProperty("name") String name, - @JsonProperty("company") String company, - @JsonProperty("html_url") String url) { + @JsonProperty("id") long id, + @JsonProperty("name") String name, + @JsonProperty("company") String company, + @JsonProperty("html_url") String url) { this.login = login; this.id = id; this.name = name; @@ -46,15 +48,19 @@ public User(@JsonProperty("login") String login, public String getLogin() { return login; } + public long getId() { return id; } + public String getName() { return name; } + public String getCompany() { return company; } + public String getUrl() { return url; } @@ -69,11 +75,12 @@ public static class Repository extends GitHubModel { private String description; @JsonProperty("html_url") private String url; + @JsonCreator public Repository(@JsonProperty("id") long id, - @JsonProperty("full_name") String name, - @JsonProperty("description") String description, - @JsonProperty("html_url") String url) { + @JsonProperty("full_name") String name, + @JsonProperty("description") String description, + @JsonProperty("html_url") String url) { this.id = id; this.name = name; this.description = description; @@ -83,12 +90,15 @@ public Repository(@JsonProperty("id") long id, public long getId() { return id; } + public String getName() { return name; } + public String getDescription() { return description; } + public String getUrl() { return url; } @@ -123,13 +133,13 @@ public static class Issue extends GitHubModel { @JsonCreator public Issue(@JsonProperty("id") long id, - @JsonProperty("number") long number, - @JsonProperty("title") String title, - @JsonProperty("state") String state, - @JsonProperty("html_url") String url, - @JsonProperty("labels") Label[] labels, - @JsonProperty("created_at") String createdAt, - @JsonProperty("closed_at") String closedAt) { + @JsonProperty("number") long number, + @JsonProperty("title") String title, + @JsonProperty("state") String state, + @JsonProperty("html_url") String url, + @JsonProperty("labels") Label[] labels, + @JsonProperty("created_at") String createdAt, + @JsonProperty("closed_at") String closedAt) { this.id = id; this.number = number; this.title = title; @@ -143,24 +153,31 @@ public Issue(@JsonProperty("id") long id, public long getId() { return id; } + public long getNumber() { return number; } + public String getTitle() { return title; } + public String getState() { return state; } + public String getUrl() { return url; } + public Label[] getLabels() { return labels; } + public String getCreatedAt() { return createdAt; } + public String getClosedAt() { return closedAt; } @@ -172,14 +189,14 @@ public static class IssueDetail extends Issue { @JsonCreator public IssueDetail(@JsonProperty("id") long id, - @JsonProperty("number") long number, - @JsonProperty("title") String title, - @JsonProperty("state") String state, - @JsonProperty("html_url") String url, - @JsonProperty("labels") Label[] labels, - @JsonProperty("created_at") String createdAt, - @JsonProperty("closed_at") String closedAt, - @JsonProperty("body") String body) { + @JsonProperty("number") long number, + @JsonProperty("title") String title, + @JsonProperty("state") String state, + @JsonProperty("html_url") String url, + @JsonProperty("labels") Label[] labels, + @JsonProperty("created_at") String createdAt, + @JsonProperty("closed_at") String closedAt, + @JsonProperty("body") String body) { super(id, number, title, state, url, labels, createdAt, closedAt); this.body = body; } @@ -199,8 +216,8 @@ public static class Label extends GitHubModel { @JsonCreator public Label(@JsonProperty("id") long id, - @JsonProperty("name") String name, - @JsonProperty("description") String description) { + @JsonProperty("name") String name, + @JsonProperty("description") String description) { this.id = id; this.name = name; this.description = description; @@ -209,9 +226,11 @@ public Label(@JsonProperty("id") long id, public long getId() { return id; } + public String getName() { return name; } + public String getDescription() { return description; } diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java index d3c59a15..f0bddee1 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.samples.plugins.github; import reactor.core.publisher.Mono; @@ -12,79 +13,47 @@ public class GitHubPlugin { public static final String baseUrl = "https://api.github.com"; private final String token; - public GitHubPlugin(String token) { + public GitHubPlugin(String token) { this.token = token; } - @DefineKernelFunction(name = "get_user_info", description = "Get user information from GitHub", - returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$User") + @DefineKernelFunction(name = "get_user_info", description = "Get user information from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$User") public Mono getUserProfileAsync() { HttpClient client = createClient(); return makeRequestAsync(client, "/user") - .map(json -> { - try { - return GitHubModel.objectMapper.readValue(json, GitHubModel.User.class); - } catch (IOException e) { - throw new IllegalStateException("Failed to deserialize GitHubUser", e); - } - }); + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.User.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubUser", e); + } + }); } - @DefineKernelFunction(name = "get_repo_info", description = "Get repository information from GitHub", - returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$Repository") + @DefineKernelFunction(name = "get_repo_info", description = "Get repository information from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$Repository") public Mono getRepositoryAsync( - @KernelFunctionParameter( - name = "organization", - description = "The name of the repository to retrieve information for" - ) String organization, - @KernelFunctionParameter( - name = "repo_name", - description = "The name of the repository to retrieve information for" - ) String repoName - ) { + @KernelFunctionParameter(name = "organization", description = "The name of the repository to retrieve information for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve information for") String repoName) { HttpClient client = createClient(); return makeRequestAsync(client, String.format("/repos/%s/%s", organization, repoName)) - .map(json -> { - try { - return GitHubModel.objectMapper.readValue(json, GitHubModel.Repository.class); - } catch (IOException e) { - throw new IllegalStateException("Failed to deserialize GitHubRepository", e); - } - }); + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.Repository.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubRepository", e); + } + }); } - @DefineKernelFunction(name = "get_issues", description = "Get issues from GitHub", - returnType = "java.util.List") + @DefineKernelFunction(name = "get_issues", description = "Get issues from GitHub", returnType = "java.util.List") public Mono> getIssuesAsync( - @KernelFunctionParameter( - name = "organization", - description = "The name of the organization to retrieve issues for" - ) String organization, - @KernelFunctionParameter( - name = "repo_name", - description = "The name of the repository to retrieve issues for" - ) String repoName, - @KernelFunctionParameter( - name = "max_results", - description = "The maximum number of issues to retrieve", - required = false, - defaultValue = "10", - type = int.class - ) int maxResults, - @KernelFunctionParameter( - name = "state", - description = "The state of the issues to retrieve", - required = false, - defaultValue = "open" - ) String state, - @KernelFunctionParameter( - name = "assignee", - description = "The assignee of the issues to retrieve", - required = false - ) String assignee - ) { + @KernelFunctionParameter(name = "organization", description = "The name of the organization to retrieve issues for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve issues for") String repoName, + @KernelFunctionParameter(name = "max_results", description = "The maximum number of issues to retrieve", required = false, defaultValue = "10", type = int.class) int maxResults, + @KernelFunctionParameter(name = "state", description = "The state of the issues to retrieve", required = false, defaultValue = "open") String state, + @KernelFunctionParameter(name = "assignee", description = "The assignee of the issues to retrieve", required = false) String assignee) { HttpClient client = createClient(); String query = String.format("/repos/%s/%s/issues", organization, repoName); @@ -93,58 +62,49 @@ public Mono> getIssuesAsync( query = buildQueryString(query, "per_page", String.valueOf(maxResults)); return makeRequestAsync(client, query) - .flatMap(json -> { - try { - GitHubModel.Issue[] issues = GitHubModel.objectMapper.readValue(json, GitHubModel.Issue[].class); - return Mono.just(List.of(issues)); - } catch (IOException e) { - throw new IllegalStateException("Failed to deserialize GitHubIssues", e); - } - }); + .flatMap(json -> { + try { + GitHubModel.Issue[] issues = GitHubModel.objectMapper.readValue(json, + GitHubModel.Issue[].class); + return Mono.just(List.of(issues)); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubIssues", e); + } + }); } - @DefineKernelFunction(name = "get_issue_detail_info", description = "Get detail information of a single issue from GitHub", - returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$IssueDetail") + @DefineKernelFunction(name = "get_issue_detail_info", description = "Get detail information of a single issue from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$IssueDetail") public GitHubModel.IssueDetail getIssueDetailAsync( - @KernelFunctionParameter( - name = "organization", - description = "The name of the repository to retrieve information for" - ) String organization, - @KernelFunctionParameter( - name = "repo_name", - description = "The name of the repository to retrieve information for" - ) String repoName, - @KernelFunctionParameter( - name = "issue_number", - description = "The issue number to retrieve information for", - type = int.class - ) int issueNumber - ) { + @KernelFunctionParameter(name = "organization", description = "The name of the repository to retrieve information for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve information for") String repoName, + @KernelFunctionParameter(name = "issue_number", description = "The issue number to retrieve information for", type = int.class) int issueNumber) { HttpClient client = createClient(); - return makeRequestAsync(client, String.format("/repos/%s/%s/issues/%d", organization, repoName, issueNumber)) - .map(json -> { - try { - return GitHubModel.objectMapper.readValue(json, GitHubModel.IssueDetail.class); - } catch (IOException e) { - throw new IllegalStateException("Failed to deserialize GitHubIssue", e); - } - }).block(); + return makeRequestAsync(client, + String.format("/repos/%s/%s/issues/%d", organization, repoName, issueNumber)) + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.IssueDetail.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubIssue", e); + } + }).block(); } private HttpClient createClient() { return HttpClient.create() - .baseUrl(baseUrl) - .headers(headers -> { - headers.add("User-Agent", "request"); - headers.add("Accept", "application/vnd.github+json"); - headers.add("Authorization", "Bearer " + token); - headers.add("X-GitHub-Api-Version", "2022-11-28"); - }); + .baseUrl(baseUrl) + .headers(headers -> { + headers.add("User-Agent", "request"); + headers.add("Accept", "application/vnd.github+json"); + headers.add("Authorization", "Bearer " + token); + headers.add("X-GitHub-Api-Version", "2022-11-28"); + }); } private static String buildQueryString(String path, String param, String value) { - if (value == null || value.isEmpty() || value.equals(KernelFunctionParameter.NO_DEFAULT_VALUE)) { + if (value == null || value.isEmpty() + || value.equals(KernelFunctionParameter.NO_DEFAULT_VALUE)) { return path; } @@ -153,13 +113,13 @@ private static String buildQueryString(String path, String param, String value) private Mono makeRequestAsync(HttpClient client, String path) { return client - .get() - .uri(path) - .responseSingle((res, content) -> { - if (res.status().code() != 200) { - return Mono.error(new IllegalStateException("Request failed: " + res.status())); - } - return content.asString(); - }); + .get() + .uri(path) + .responseSingle((res, content) -> { + if (res.status().code() != 200) { + return Mono.error(new IllegalStateException("Request failed: " + res.status())); + } + return content.asString(); + }); } } diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java index 1e5a665a..336406a3 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/agents/CompletionAgent.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.samples.syntaxexamples.agents; import com.azure.ai.openai.OpenAIAsyncClient; @@ -6,11 +7,13 @@ import com.azure.core.credential.KeyCredential; import com.microsoft.semantickernel.Kernel; import com.microsoft.semantickernel.agents.AgentInvokeOptions; +import com.microsoft.semantickernel.agents.AgentThread; import com.microsoft.semantickernel.agents.chatcompletion.ChatCompletionAgent; import com.microsoft.semantickernel.agents.chatcompletion.ChatHistoryAgentThread; import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion; import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; import com.microsoft.semantickernel.implementation.templateengine.tokenizer.DefaultPromptTemplate; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; @@ -34,9 +37,10 @@ public class CompletionAgent { // Only required if AZURE_CLIENT_KEY is set private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); private static final String MODEL_ID = System.getenv() - .getOrDefault("MODEL_ID", "gpt-4o"); + .getOrDefault("MODEL_ID", "gpt-4o"); private static final String GITHUB_PAT = System.getenv("GITHUB_PAT"); + public static void main(String[] args) { System.out.println("======== ChatCompletion Agent ========"); @@ -44,69 +48,65 @@ public static void main(String[] args) { if (AZURE_CLIENT_KEY != null) { client = new OpenAIClientBuilder() - .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) - .endpoint(CLIENT_ENDPOINT) - .buildAsyncClient(); + .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) + .endpoint(CLIENT_ENDPOINT) + .buildAsyncClient(); } else { client = new OpenAIClientBuilder() - .credential(new KeyCredential(CLIENT_KEY)) - .buildAsyncClient(); + .credential(new KeyCredential(CLIENT_KEY)) + .buildAsyncClient(); } System.out.println("------------------------"); ChatCompletionService chatCompletion = OpenAIChatCompletion.builder() - .withModelId(MODEL_ID) - .withOpenAIAsyncClient(client) - .build(); + .withModelId(MODEL_ID) + .withOpenAIAsyncClient(client) + .build(); Kernel kernel = Kernel.builder() - .withAIService(ChatCompletionService.class, chatCompletion) - .withPlugin(KernelPluginFactory.createFromObject(new GitHubPlugin(GITHUB_PAT), - "GitHubPlugin")) - .build(); + .withAIService(ChatCompletionService.class, chatCompletion) + .withPlugin(KernelPluginFactory.createFromObject(new GitHubPlugin(GITHUB_PAT), + "GitHubPlugin")) + .build(); InvocationContext invocationContext = InvocationContext.builder() - .withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true)) - .withContextVariableConverter(new ContextVariableTypeConverter<>( - GitHubModel.Issue.class, - o -> (GitHubModel.Issue) o, - o -> o.toString(), - s -> null - )) - .build(); + .withFunctionChoiceBehavior(FunctionChoiceBehavior.auto(true)) + .withContextVariableConverter(new ContextVariableTypeConverter<>( + GitHubModel.Issue.class, + o -> (GitHubModel.Issue) o, + o -> o.toString(), + s -> null)) + .build(); ChatCompletionAgent agent = ChatCompletionAgent.builder() - .withKernel(kernel) - .withKernelArguments( - KernelArguments.builder() - .withVariable("repository", "microsoft/semantic-kernel-java") - .withExecutionSettings(PromptExecutionSettings.builder() - .build()) - .build() - ) - .withInvocationContext(invocationContext) - .withTemplate( - DefaultPromptTemplate.build( - PromptTemplateConfig.builder() - .withTemplate( - """ + .withKernel(kernel) + .withKernelArguments( + KernelArguments.builder() + .withVariable("repository", "microsoft/semantic-kernel-java") + .withExecutionSettings(PromptExecutionSettings.builder() + .build()) + .build()) + .withInvocationContext(invocationContext) + .withTemplate( + DefaultPromptTemplate.build( + PromptTemplateConfig.builder() + .withTemplate( + """ You are an agent designed to query and retrieve information from a single GitHub repository in a read-only manner. You are also able to access the profile of the active user. - + Use the current date and time to provide up-to-date details or time-sensitive responses. - + The repository you are querying is a public repository with the following name: {{$repository}} - + The current date and time is: {{$now}}. - """ - ) - .build() - ) - ).build(); + """) + .build())) + .build(); - ChatHistoryAgentThread agentThread = new ChatHistoryAgentThread(); + AgentThread agentThread = new ChatHistoryAgentThread(); Scanner scanner = new Scanner(System.in); while (true) { @@ -119,22 +119,19 @@ public static void main(String[] args) { var message = new ChatMessageContent<>(AuthorRole.USER, input); KernelArguments arguments = KernelArguments.builder() - .withVariable("now", System.currentTimeMillis()) - .build(); + .withVariable("now", System.currentTimeMillis()) + .build(); var response = agent.invokeAsync( - List.of(message), - agentThread, - AgentInvokeOptions.builder() - .withKernel(kernel) - .withKernelArguments(arguments) - .build() - ).block(); - - var lastResponse = response.get(response.size() - 1); - - System.out.println("> " + lastResponse.getMessage()); - agentThread = (ChatHistoryAgentThread) lastResponse.getThread(); + message, + agentThread, + AgentInvokeOptions.builder() + .withKernelArguments(arguments) + .build()) + .block().get(0); + + System.out.println("> " + response.getMessage()); + agentThread = response.getThread(); } } } diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/functions/Example59_OpenAIFunctionCalling.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/functions/Example59_OpenAIFunctionCalling.java index e52c2be7..e921bb78 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/functions/Example59_OpenAIFunctionCalling.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/functions/Example59_OpenAIFunctionCalling.java @@ -10,6 +10,7 @@ import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatMessageContent; import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIFunctionToolCall; import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; import com.microsoft.semantickernel.implementation.CollectionUtil; import com.microsoft.semantickernel.orchestration.FunctionResult; import com.microsoft.semantickernel.orchestration.FunctionResultMetadata; @@ -38,7 +39,7 @@ public class Example59_OpenAIFunctionCalling { // Only required if AZURE_CLIENT_KEY is set private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); private static final String MODEL_ID = System.getenv() - .getOrDefault("MODEL_ID", "gpt-35-turbo-2"); + .getOrDefault("MODEL_ID", "gpt-4o"); // Define functions that can be called by the model public static class HelperFunctions { @@ -118,7 +119,7 @@ public static void main(String[] args) throws NoSuchMethodException { var result = kernel .invokeAsync(function) - .withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true)) + .withFunctionChoiceBehavior(FunctionChoiceBehavior.auto(true)) .withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(String.class)) .block(); System.out.println(result.getResult()); @@ -134,7 +135,7 @@ public static void main(String[] args) throws NoSuchMethodException { chatHistory, kernel, InvocationContext.builder() - .withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(false)) + .withFunctionChoiceBehavior(FunctionChoiceBehavior.auto(false)) .withReturnMode(InvocationReturnMode.FULL_HISTORY) .build()) .block(); @@ -243,7 +244,7 @@ public static void multiTurnaroundCall() { chatHistory, kernel, InvocationContext.builder() - .withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true)) + .withFunctionChoiceBehavior(FunctionChoiceBehavior.auto(true)) .withReturnMode(InvocationReturnMode.FULL_HISTORY) .build()) .block(); @@ -258,7 +259,7 @@ public static void multiTurnaroundCall() { chatHistory, kernel, InvocationContext.builder() - .withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true)) + .withFunctionChoiceBehavior(FunctionChoiceBehavior.auto(true)) .withReturnMode(InvocationReturnMode.FULL_HISTORY) .build()) .block(); diff --git a/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java b/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java new file mode 100644 index 00000000..0f8065ee --- /dev/null +++ b/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubModel.java @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.samples.plugins.github; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; + +public abstract class GitHubModel { + public final static ObjectMapper objectMapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + + @Override + public String toString() { + try { + return objectMapper.writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static class User extends GitHubModel { + @JsonProperty("login") + private String login; + @JsonProperty("id") + private long id; + @JsonProperty("name") + private String name; + @JsonProperty("company") + private String company; + @JsonProperty("html_url") + private String url; + + @JsonCreator + public User(@JsonProperty("login") String login, + @JsonProperty("id") long id, + @JsonProperty("name") String name, + @JsonProperty("company") String company, + @JsonProperty("html_url") String url) { + this.login = login; + this.id = id; + this.name = name; + this.company = company; + this.url = url; + } + + public String getLogin() { + return login; + } + + public long getId() { + return id; + } + + public String getName() { + return name; + } + + public String getCompany() { + return company; + } + + public String getUrl() { + return url; + } + } + + public static class Repository extends GitHubModel { + @JsonProperty("id") + private long id; + @JsonProperty("full_name") + private String name; + @JsonProperty("description") + private String description; + @JsonProperty("html_url") + private String url; + + @JsonCreator + public Repository(@JsonProperty("id") long id, + @JsonProperty("full_name") String name, + @JsonProperty("description") String description, + @JsonProperty("html_url") String url) { + this.id = id; + this.name = name; + this.description = description; + this.url = url; + } + + public long getId() { + return id; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public String getUrl() { + return url; + } + + @Override + public String toString() { + try { + return objectMapper.writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + public static class Issue extends GitHubModel { + @JsonProperty("id") + private long id; + @JsonProperty("number") + private long number; + @JsonProperty("title") + private String title; + @JsonProperty("state") + private String state; + @JsonProperty("html_url") + private String url; + @JsonProperty("labels") + private Label[] labels; + @JsonProperty("created_at") + private String createdAt; + @JsonProperty("closed_at") + private String closedAt; + + @JsonCreator + public Issue(@JsonProperty("id") long id, + @JsonProperty("number") long number, + @JsonProperty("title") String title, + @JsonProperty("state") String state, + @JsonProperty("html_url") String url, + @JsonProperty("labels") Label[] labels, + @JsonProperty("created_at") String createdAt, + @JsonProperty("closed_at") String closedAt) { + this.id = id; + this.number = number; + this.title = title; + this.state = state; + this.url = url; + this.labels = labels; + this.createdAt = createdAt; + this.closedAt = closedAt; + } + + public long getId() { + return id; + } + + public long getNumber() { + return number; + } + + public String getTitle() { + return title; + } + + public String getState() { + return state; + } + + public String getUrl() { + return url; + } + + public Label[] getLabels() { + return labels; + } + + public String getCreatedAt() { + return createdAt; + } + + public String getClosedAt() { + return closedAt; + } + } + + public static class IssueDetail extends Issue { + @JsonProperty("body") + private String body; + + @JsonCreator + public IssueDetail(@JsonProperty("id") long id, + @JsonProperty("number") long number, + @JsonProperty("title") String title, + @JsonProperty("state") String state, + @JsonProperty("html_url") String url, + @JsonProperty("labels") Label[] labels, + @JsonProperty("created_at") String createdAt, + @JsonProperty("closed_at") String closedAt, + @JsonProperty("body") String body) { + super(id, number, title, state, url, labels, createdAt, closedAt); + this.body = body; + } + + public String getBody() { + return body; + } + } + + public static class Label extends GitHubModel { + @JsonProperty("id") + private long id; + @JsonProperty("name") + private String name; + @JsonProperty("description") + private String description; + + @JsonCreator + public Label(@JsonProperty("id") long id, + @JsonProperty("name") String name, + @JsonProperty("description") String description) { + this.id = id; + this.name = name; + this.description = description; + } + + public long getId() { + return id; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + } +} diff --git a/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java b/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java new file mode 100644 index 00000000..f0bddee1 --- /dev/null +++ b/samples/semantickernel-learn-resources/src/main/java/com/microsoft/semantickernel/samples/plugins/github/GitHubPlugin.java @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.samples.plugins.github; + +import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction; +import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter; + +import java.io.IOException; +import java.util.List; + +public class GitHubPlugin { + public static final String baseUrl = "https://api.github.com"; + private final String token; + + public GitHubPlugin(String token) { + this.token = token; + } + + @DefineKernelFunction(name = "get_user_info", description = "Get user information from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$User") + public Mono getUserProfileAsync() { + HttpClient client = createClient(); + + return makeRequestAsync(client, "/user") + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.User.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubUser", e); + } + }); + } + + @DefineKernelFunction(name = "get_repo_info", description = "Get repository information from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$Repository") + public Mono getRepositoryAsync( + @KernelFunctionParameter(name = "organization", description = "The name of the repository to retrieve information for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve information for") String repoName) { + HttpClient client = createClient(); + + return makeRequestAsync(client, String.format("/repos/%s/%s", organization, repoName)) + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.Repository.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubRepository", e); + } + }); + } + + @DefineKernelFunction(name = "get_issues", description = "Get issues from GitHub", returnType = "java.util.List") + public Mono> getIssuesAsync( + @KernelFunctionParameter(name = "organization", description = "The name of the organization to retrieve issues for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve issues for") String repoName, + @KernelFunctionParameter(name = "max_results", description = "The maximum number of issues to retrieve", required = false, defaultValue = "10", type = int.class) int maxResults, + @KernelFunctionParameter(name = "state", description = "The state of the issues to retrieve", required = false, defaultValue = "open") String state, + @KernelFunctionParameter(name = "assignee", description = "The assignee of the issues to retrieve", required = false) String assignee) { + HttpClient client = createClient(); + + String query = String.format("/repos/%s/%s/issues", organization, repoName); + query = buildQueryString(query, "state", state); + query = buildQueryString(query, "assignee", assignee); + query = buildQueryString(query, "per_page", String.valueOf(maxResults)); + + return makeRequestAsync(client, query) + .flatMap(json -> { + try { + GitHubModel.Issue[] issues = GitHubModel.objectMapper.readValue(json, + GitHubModel.Issue[].class); + return Mono.just(List.of(issues)); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubIssues", e); + } + }); + } + + @DefineKernelFunction(name = "get_issue_detail_info", description = "Get detail information of a single issue from GitHub", returnType = "com.microsoft.semantickernel.samples.plugins.github.GitHubModel$IssueDetail") + public GitHubModel.IssueDetail getIssueDetailAsync( + @KernelFunctionParameter(name = "organization", description = "The name of the repository to retrieve information for") String organization, + @KernelFunctionParameter(name = "repo_name", description = "The name of the repository to retrieve information for") String repoName, + @KernelFunctionParameter(name = "issue_number", description = "The issue number to retrieve information for", type = int.class) int issueNumber) { + HttpClient client = createClient(); + + return makeRequestAsync(client, + String.format("/repos/%s/%s/issues/%d", organization, repoName, issueNumber)) + .map(json -> { + try { + return GitHubModel.objectMapper.readValue(json, GitHubModel.IssueDetail.class); + } catch (IOException e) { + throw new IllegalStateException("Failed to deserialize GitHubIssue", e); + } + }).block(); + } + + private HttpClient createClient() { + return HttpClient.create() + .baseUrl(baseUrl) + .headers(headers -> { + headers.add("User-Agent", "request"); + headers.add("Accept", "application/vnd.github+json"); + headers.add("Authorization", "Bearer " + token); + headers.add("X-GitHub-Api-Version", "2022-11-28"); + }); + } + + private static String buildQueryString(String path, String param, String value) { + if (value == null || value.isEmpty() + || value.equals(KernelFunctionParameter.NO_DEFAULT_VALUE)) { + return path; + } + + return path + (path.contains("?") ? "&" : "?") + param + "=" + value; + } + + private Mono makeRequestAsync(HttpClient client, String path) { + return client + .get() + .uri(path) + .responseSingle((res, content) -> { + if (res.status().code() != 200) { + return Mono.error(new IllegalStateException("Request failed: " + res.status())); + } + return content.asString(); + }); + } +} diff --git a/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/src/main/java/com/microsoft/semantickernel/samples/openapi/OpenAPIHttpRequestPlugin.java b/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/src/main/java/com/microsoft/semantickernel/samples/openapi/OpenAPIHttpRequestPlugin.java index 58b3e525..ebdc8449 100644 --- a/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/src/main/java/com/microsoft/semantickernel/samples/openapi/OpenAPIHttpRequestPlugin.java +++ b/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/src/main/java/com/microsoft/semantickernel/samples/openapi/OpenAPIHttpRequestPlugin.java @@ -142,7 +142,7 @@ private String buildQueryPath(KernelArguments arguments) { } private static String getRenderedParameter( - KernelArguments arguments, String name) { + KernelArguments arguments, String name) { ContextVariable value = arguments.get(name); if (value == null) { diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java index d8059ef7..5e6cc2fd 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java @@ -192,7 +192,7 @@ public FunctionInvocation invokePromptAsync(@Nonnull String prompt, */ public FunctionInvocation invokePromptAsync(@Nonnull String prompt, - @Nonnull KernelArguments arguments, @Nonnull InvocationContext invocationContext) { + @Nonnull KernelArguments arguments, @Nonnull InvocationContext invocationContext) { KernelFunction function = KernelFunction.createFromPrompt(prompt).build(); @@ -327,11 +327,12 @@ public T getService(Class clazz) throws ServiceNotFound * @throws ServiceNotFoundException if the service is not found. * @see com.microsoft.semantickernel.services.AIServiceSelector#trySelectAIService(Class, KernelArguments) */ - public T getService(Class clazz, KernelArguments args) throws ServiceNotFoundException { + public T getService(Class clazz, KernelArguments args) + throws ServiceNotFoundException { AIServiceSelection selector = serviceSelector - .trySelectAIService( - clazz, - args); + .trySelectAIService( + clazz, + args); if (selector == null) { throw new ServiceNotFoundException("Unable to find service of type " + clazz.getName()); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java index 3a82550c..f69b2152 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/Agent.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; import java.util.HashMap; @@ -17,6 +18,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import javax.annotation.Nullable; + /** * Interface for a semantic kernel agent. */ @@ -43,6 +46,39 @@ public interface Agent { */ String getDescription(); + /** + * Invokes the agent with the given message. + * + * @param message The message to process + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync( + @Nullable ChatMessageContent message); + + /** + * Invokes the agent with the given message and thread. + * + * @param message The message to process + * @param thread The agent thread to use + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync( + @Nullable ChatMessageContent message, + @Nullable AgentThread thread); + + /** + * Invokes the agent with the given message, thread, and options. + * + * @param message The message to process + * @param thread The agent thread to use + * @param options The options for invoking the agent + * @return A Mono containing the agent response + */ + Mono>>> invokeAsync( + @Nullable ChatMessageContent message, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options); + /** * Invoke the agent with the given chat history. * @@ -51,7 +87,10 @@ public interface Agent { * @param options The options for invoking the agent * @return A Mono containing the agent response */ - Mono>>> invokeAsync(List> messages, AgentThread thread, AgentInvokeOptions options); + Mono>>> invokeAsync( + List> messages, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options); /** * Notifies the agent of a new message. diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java index 3fb4cba1..6b6d57ed 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentInvokeOptions.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; import com.microsoft.semantickernel.Kernel; @@ -13,9 +14,13 @@ */ public class AgentInvokeOptions { + @Nullable private final KernelArguments kernelArguments; + @Nullable private final Kernel kernel; + @Nullable private final String additionalInstructions; + @Nullable private final InvocationContext invocationContext; /** @@ -34,9 +39,9 @@ public AgentInvokeOptions() { * @param invocationContext The invocation context. */ public AgentInvokeOptions(@Nullable KernelArguments kernelArguments, - @Nullable Kernel kernel, - @Nullable String additionalInstructions, - @Nullable InvocationContext invocationContext) { + @Nullable Kernel kernel, + @Nullable String additionalInstructions, + @Nullable InvocationContext invocationContext) { this.kernelArguments = kernelArguments != null ? kernelArguments.copy() : null; this.kernel = kernel; this.additionalInstructions = additionalInstructions; @@ -80,8 +85,6 @@ public InvocationContext getInvocationContext() { return invocationContext; } - - /** * Builder for AgentInvokeOptions. */ @@ -152,8 +155,7 @@ public AgentInvokeOptions build() { kernelArguments, kernel, additionalInstructions, - invocationContext - ); + invocationContext); } } } \ No newline at end of file diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentResponseItem.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentResponseItem.java index f585bfea..0b455098 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentResponseItem.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentResponseItem.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentThread.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentThread.java index d369d999..94538f41 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentThread.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/AgentThread.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent; diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/BaseAgentThread.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/BaseAgentThread.java index b7c97eea..c66fe9b9 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/BaseAgentThread.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/BaseAgentThread.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; public abstract class BaseAgentThread implements AgentThread { @@ -16,6 +17,7 @@ public BaseAgentThread(String id) { public String getId() { return id; } + @Override public boolean isDeleted() { return isDeleted; diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java index 71e81951..8403093e 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/agents/KernelAgent.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.agents; import com.microsoft.semantickernel.Kernel; @@ -11,6 +12,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,23 +32,24 @@ public abstract class KernelAgent implements Agent { protected final PromptTemplate template; protected KernelAgent( - String id, - String name, - String description, - Kernel kernel, - KernelArguments kernelArguments, - InvocationContext invocationContext, - String instructions, - PromptTemplate template - ) { + String id, + String name, + String description, + Kernel kernel, + KernelArguments kernelArguments, + InvocationContext invocationContext, + String instructions, + PromptTemplate template) { this.id = id != null ? id : UUID.randomUUID().toString(); this.name = name; this.description = description; this.kernel = kernel; this.kernelArguments = kernelArguments != null - ? kernelArguments.copy() : KernelArguments.builder().build(); + ? kernelArguments.copy() + : KernelArguments.builder().build(); this.invocationContext = invocationContext != null - ? invocationContext : InvocationContext.builder().build(); + ? invocationContext + : InvocationContext.builder().build(); this.instructions = instructions; this.template = template; } @@ -114,7 +118,6 @@ public PromptTemplate getTemplate() { return template; } - /** * Merges the provided arguments with the current arguments. * Provided arguments will override the current arguments. @@ -126,14 +129,15 @@ protected KernelArguments mergeArguments(KernelArguments arguments) { return kernelArguments; } - Map executionSettings = new HashMap<>(kernelArguments.getExecutionSettings()); + Map executionSettings = new HashMap<>( + kernelArguments.getExecutionSettings()); executionSettings.putAll(arguments.getExecutionSettings()); return KernelArguments.builder() - .withVariables(kernelArguments) - .withVariables(arguments) - .withExecutionSettings(executionSettings) - .build(); + .withVariables(kernelArguments) + .withVariables(arguments) + .withExecutionSettings(executionSettings) + .build(); } /** @@ -144,7 +148,8 @@ protected KernelArguments mergeArguments(KernelArguments arguments) { * @param context The context to use for formatting. * @return A Mono that resolves to the formatted instructions. */ - protected Mono renderInstructionsAsync(Kernel kernel, KernelArguments arguments, InvocationContext context) { + protected Mono renderInstructionsAsync(Kernel kernel, KernelArguments arguments, + InvocationContext context) { if (template != null) { return template.renderAsync(kernel, arguments, context); } else { @@ -152,19 +157,45 @@ protected Mono renderInstructionsAsync(Kernel kernel, KernelArguments ar } } - protected Mono ensureThreadExistsWithMessagesAsync(List> messages, AgentThread thread, Supplier threadSupplier) { + protected Mono ensureThreadExistsWithMessagesAsync( + List> messages, AgentThread thread, Supplier threadSupplier) { return Mono.defer(() -> { // Check if the thread already exists // If it does, we can work with a copy of it AgentThread newThread = thread == null ? threadSupplier.get() : thread.copy(); - return newThread.createAsync() - .thenMany(Flux.fromIterable(messages)) - .concatMap(message -> { - return notifyThreadOfNewMessageAsync(newThread, message) - .then(Mono.just(message)); - }) - .then(Mono.just((T) newThread)); + return newThread.createAsync() + .thenMany(Flux.fromIterable(messages)) + .concatMap(message -> { + return notifyThreadOfNewMessageAsync(newThread, message) + .then(Mono.just(message)); + }) + .then(Mono.just((T) newThread)); }); } + + @Override + public Mono>>> invokeAsync( + @Nullable ChatMessageContent message) { + return invokeAsync(message, null, null); + } + + @Override + public Mono>>> invokeAsync( + @Nullable ChatMessageContent message, + @Nullable AgentThread thread) { + return invokeAsync(message, thread, null); + } + + @Override + public Mono>>> invokeAsync( + @Nullable ChatMessageContent message, + @Nullable AgentThread thread, + @Nullable AgentInvokeOptions options) { + ArrayList> messages = new ArrayList<>(); + if (message != null) { + messages.add(message); + } + return invokeAsync(messages, thread, options); + } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/AutoFunctionChoiceBehavior.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/AutoFunctionChoiceBehavior.java new file mode 100644 index 00000000..b4993f44 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/AutoFunctionChoiceBehavior.java @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.functionchoice; + +import com.microsoft.semantickernel.semanticfunctions.KernelFunction; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * A set of allowed kernel functions. All kernel functions are allowed if allKernelFunctionsAllowed is true. + * Otherwise, only the functions in allowedFunctions are allowed. + *

+ * If a function is allowed, it may be called. If it is not allowed, it will not be called. + */ +public class AutoFunctionChoiceBehavior extends FunctionChoiceBehavior { + private final boolean autoInvoke; + + /** + * Create a new instance of AutoFunctionChoiceBehavior. + * + * @param autoInvoke Whether auto-invocation is enabled. + * @param functions A set of functions to advertise to the model. + * @param options Options for the function choice behavior. + */ + public AutoFunctionChoiceBehavior(boolean autoInvoke, + @Nullable List> functions, + @Nullable FunctionChoiceBehaviorOptions options) { + super(functions, options); + this.autoInvoke = autoInvoke; + } + + /** + * Check whether the given function is allowed. + * + * @return Whether the function is allowed. + */ + public boolean isAutoInvoke() { + return autoInvoke; + } +} \ No newline at end of file diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehavior.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehavior.java new file mode 100644 index 00000000..cbf64bde --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehavior.java @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.functionchoice; + +import com.microsoft.semantickernel.semanticfunctions.KernelFunction; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * Defines the behavior of a tool call. Currently, the only tool available is function calling. + */ +public abstract class FunctionChoiceBehavior { + private final Set fullFunctionNames; + + protected final List> functions; + protected final FunctionChoiceBehaviorOptions options; + + protected FunctionChoiceBehavior(@Nullable List> functions, + @Nullable FunctionChoiceBehaviorOptions options) { + this.functions = functions != null ? Collections.unmodifiableList(functions) : null; + this.fullFunctionNames = new HashSet<>(); + + if (functions != null) { + functions.stream().filter(Objects::nonNull).forEach( + f -> this.fullFunctionNames + .add(formFullFunctionName(f.getPluginName(), f.getName()))); + } + + if (options != null) { + this.options = options; + } else { + this.options = FunctionChoiceBehaviorOptions.builder().build(); + } + } + + /** + * Gets the functions that are allowed. + * + * @return The functions that are allowed. + */ + public List> getFunctions() { + return Collections.unmodifiableList(functions); + } + + /** + * Gets the options for the function choice behavior. + * + * @return The options for the function choice behavior. + */ + public FunctionChoiceBehaviorOptions getOptions() { + return options; + } + + /** + * Gets an instance of the FunctionChoiceBehavior that provides all the Kernel's plugins functions to the AI model to call. + * + * @param autoInvoke Indicates whether the functions should be automatically invoked by AI connectors + * + * @return A new ToolCallBehavior instance with all kernel functions allowed. + */ + public static FunctionChoiceBehavior auto(boolean autoInvoke) { + return new AutoFunctionChoiceBehavior(autoInvoke, null, null); + } + + /** + * Gets an instance of the FunctionChoiceBehavior that provides either all the Kernel's plugins functions to the AI model to call or specific functions. + * + * @param autoInvoke Enable or disable auto-invocation. + * If auto-invocation is enabled, the model may request that the Semantic Kernel + * invoke the kernel functions and return the value to the model. + * @param functions Functions to provide to the model. If null, all the Kernel's plugins' functions are provided to the model. + * If empty, no functions are provided to the model, which is equivalent to disabling function calling. + * @param options Options for the function choice behavior. + * + * @return A new FunctionChoiceBehavior instance with all kernel functions allowed. + */ + public static FunctionChoiceBehavior auto(boolean autoInvoke, + List> functions, + @Nullable FunctionChoiceBehaviorOptions options) { + return new AutoFunctionChoiceBehavior(autoInvoke, functions, options); + } + + /** + * Gets an instance of the FunctionChoiceBehavior that provides either all the Kernel's plugins functions to the AI model to call or specific functions. + *

+ * This behavior forces the model to call the provided functions. + * SK connectors will invoke a requested function or multiple requested functions if the model requests multiple ones in one request, + * while handling the first request, and stop advertising the functions for the following requests to prevent the model from repeatedly calling the same function(s). + * + * @param functions Functions to provide to the model. If null, all the Kernel's plugins' functions are provided to the model. + * If empty, no functions are provided to the model, which is equivalent to disabling function calling. + * @return A new FunctionChoiceBehavior instance with the required function. + */ + public static FunctionChoiceBehavior required(boolean autoInvoke, + List> functions, + @Nullable FunctionChoiceBehaviorOptions options) { + return new RequiredFunctionChoiceBehavior(autoInvoke, functions, options); + } + + /** + * Gets an instance of the FunctionChoiceBehavior that provides either all the Kernel's plugins functions to the AI model to call or specific functions. + *

+ * This behavior is useful if the user should first validate what functions the model will use. + * + * @param functions Functions to provide to the model. If null, all the Kernel's plugins' functions are provided to the model. + * If empty, no functions are provided to the model, which is equivalent to disabling function calling. + */ + public static FunctionChoiceBehavior none(List> functions, + @Nullable FunctionChoiceBehaviorOptions options) { + return new NoneFunctionChoiceBehavior(functions, options); + } + + /** + * The separator between the plugin name and the function name. + */ + public static final String FUNCTION_NAME_SEPARATOR = "-"; + + /** + * Form the full function name. + * + * @param pluginName The name of the plugin that the function is in. + * @param functionName The name of the function. + * @return The key for the function. + */ + public static String formFullFunctionName(@Nullable String pluginName, String functionName) { + if (pluginName == null) { + pluginName = ""; + } + return String.format("%s%s%s", pluginName, FUNCTION_NAME_SEPARATOR, functionName); + } + + /** + * Check whether the given function is allowed. + * + * @param function The function to check. + * @return Whether the function is allowed. + */ + public boolean isFunctionAllowed(KernelFunction function) { + return isFunctionAllowed(function.getPluginName(), function.getName()); + } + + /** + * Check whether the given function is allowed. + * + * @param pluginName The name of the plugin that the function is in. + * @param functionName The name of the function. + * @return Whether the function is allowed. + */ + public boolean isFunctionAllowed(@Nullable String pluginName, String functionName) { + // If no functions are provided, all functions are allowed. + if (functions == null || functions.isEmpty()) { + return true; + } + + String key = formFullFunctionName(pluginName, functionName); + return fullFunctionNames.contains(key); + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehaviorOptions.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehaviorOptions.java new file mode 100644 index 00000000..ffb17c78 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/FunctionChoiceBehaviorOptions.java @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.functionchoice; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; + +public class FunctionChoiceBehaviorOptions { + private final boolean parallelCallsAllowed; + + private FunctionChoiceBehaviorOptions(boolean parallelCallsAllowed) { + this.parallelCallsAllowed = parallelCallsAllowed; + } + + /** + * Returns a new builder for {@link FunctionChoiceBehaviorOptions}. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Indicates whether parallel calls to functions are allowed. + * + * @return True if parallel calls are allowed; otherwise, false. + */ + public boolean isParallelCallsAllowed() { + return parallelCallsAllowed; + } + + /** + * Builder for {@link FunctionChoiceBehaviorOptions}. + */ + public static class Builder implements SemanticKernelBuilder { + private boolean allowParallelCalls = false; + + /** + * Sets whether parallel calls to functions are allowed. + * + * @param allowParallelCalls True if parallel calls are allowed; otherwise, false. + * @return The builder instance. + */ + public Builder withParallelCallsAllowed(boolean allowParallelCalls) { + this.allowParallelCalls = allowParallelCalls; + return this; + } + + public FunctionChoiceBehaviorOptions build() { + return new FunctionChoiceBehaviorOptions(allowParallelCalls); + } + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/NoneFunctionChoiceBehavior.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/NoneFunctionChoiceBehavior.java new file mode 100644 index 00000000..e5bef9f4 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/NoneFunctionChoiceBehavior.java @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.functionchoice; + +import com.microsoft.semantickernel.semanticfunctions.KernelFunction; + +import java.util.List; + +public class NoneFunctionChoiceBehavior extends FunctionChoiceBehavior { + + /** + * Create a new instance of NoneFunctionChoiceBehavior. + */ + public NoneFunctionChoiceBehavior(List> functions, + FunctionChoiceBehaviorOptions options) { + super(functions, options); + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/RequiredFunctionChoiceBehavior.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/RequiredFunctionChoiceBehavior.java new file mode 100644 index 00000000..01710cca --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/functionchoice/RequiredFunctionChoiceBehavior.java @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.functionchoice; + +import com.microsoft.semantickernel.semanticfunctions.KernelFunction; + +import java.util.List; + +public class RequiredFunctionChoiceBehavior extends AutoFunctionChoiceBehavior { + + /** + * Create a new instance of RequiredFunctionChoiceBehavior. + * + * @param autoInvoke Whether auto-invocation is enabled. + * @param functions A set of functions to advertise to the model. + * @param options Options for the function choice behavior. + */ + public RequiredFunctionChoiceBehavior(boolean autoInvoke, List> functions, + FunctionChoiceBehaviorOptions options) { + super(autoInvoke, functions, options); + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java index 0ae16e19..9b8a518c 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java @@ -8,6 +8,7 @@ import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; import com.microsoft.semantickernel.contextvariables.converters.ContextVariableJacksonConverter; import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; import com.microsoft.semantickernel.hooks.KernelHook; import com.microsoft.semantickernel.hooks.KernelHooks; import com.microsoft.semantickernel.hooks.KernelHooks.UnmodifiableKernelHooks; @@ -48,6 +49,9 @@ public class FunctionInvocation extends Mono> { protected PromptExecutionSettings promptExecutionSettings; @Nullable protected ToolCallBehavior toolCallBehavior; + @Nullable + protected FunctionChoiceBehavior functionChoiceBehavior; + @Nullable protected SemanticKernelTelemetry telemetry; @@ -196,6 +200,7 @@ public FunctionInvocation withResultType(ContextVariableType resultTyp .withArguments(arguments) .addKernelHooks(hooks) .withPromptExecutionSettings(promptExecutionSettings) + .withFunctionChoiceBehavior(functionChoiceBehavior) .withToolCallBehavior(toolCallBehavior) .withTypes(contextVariableTypes); } @@ -287,10 +292,32 @@ public FunctionInvocation withPromptExecutionSettings( */ public FunctionInvocation withToolCallBehavior(@Nullable ToolCallBehavior toolCallBehavior) { logSubscribeWarning(); + if (toolCallBehavior != null && functionChoiceBehavior != null) { + throw new SKException( + "ToolCallBehavior cannot be set when FunctionChoiceBehavior is set."); + } this.toolCallBehavior = toolCallBehavior; return this; } + /** + * Supply function choice behavior to the function invocation. + * + * @param functionChoiceBehavior The function choice behavior to supply to the function + * invocation. + * @return this {@code FunctionInvocation} for fluent chaining. + */ + public FunctionInvocation withFunctionChoiceBehavior( + @Nullable FunctionChoiceBehavior functionChoiceBehavior) { + if (functionChoiceBehavior != null && toolCallBehavior != null) { + throw new SKException( + "FunctionChoiceBehavior cannot be set when ToolCallBehavior is set."); + } + logSubscribeWarning(); + this.functionChoiceBehavior = functionChoiceBehavior; + return this; + } + /** * Supply a type converter to the function invocation. * @@ -340,6 +367,7 @@ public FunctionInvocation withInvocationContext( } logSubscribeWarning(); withTypes(invocationContext.getContextVariableTypes()); + withFunctionChoiceBehavior(invocationContext.getFunctionChoiceBehavior()); withToolCallBehavior(invocationContext.getToolCallBehavior()); withPromptExecutionSettings(invocationContext.getPromptExecutionSettings()); addKernelHooks(invocationContext.getKernelHooks()); @@ -387,6 +415,7 @@ public void subscribe(CoreSubscriber> coreSubscriber) hooks, promptExecutionSettings, toolCallBehavior, + functionChoiceBehavior, contextVariableTypes, InvocationReturnMode.NEW_MESSAGES_ONLY, telemetry)); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/InvocationContext.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/InvocationContext.java index 6fd3f0d2..6a8547c4 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/InvocationContext.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/InvocationContext.java @@ -4,6 +4,8 @@ import com.microsoft.semantickernel.builders.SemanticKernelBuilder; import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.functionchoice.FunctionChoiceBehavior; import com.microsoft.semantickernel.hooks.KernelHooks; import com.microsoft.semantickernel.hooks.KernelHooks.UnmodifiableKernelHooks; import com.microsoft.semantickernel.implementation.telemetry.SemanticKernelTelemetry; @@ -23,6 +25,8 @@ public class InvocationContext { private final PromptExecutionSettings promptExecutionSettings; @Nullable private final ToolCallBehavior toolCallBehavior; + @Nullable + private final FunctionChoiceBehavior functionChoiceBehavior; private final ContextVariableTypes contextVariableTypes; private final InvocationReturnMode invocationReturnMode; private final SemanticKernelTelemetry telemetry; @@ -39,12 +43,14 @@ protected InvocationContext( @Nullable KernelHooks hooks, @Nullable PromptExecutionSettings promptExecutionSettings, @Nullable ToolCallBehavior toolCallBehavior, + @Nullable FunctionChoiceBehavior functionChoiceBehavior, @Nullable ContextVariableTypes contextVariableTypes, InvocationReturnMode invocationReturnMode, SemanticKernelTelemetry telemetry) { this.hooks = unmodifiableClone(hooks); this.promptExecutionSettings = promptExecutionSettings; this.toolCallBehavior = toolCallBehavior; + this.functionChoiceBehavior = functionChoiceBehavior; this.invocationReturnMode = invocationReturnMode; if (contextVariableTypes == null) { this.contextVariableTypes = new ContextVariableTypes(); @@ -61,6 +67,7 @@ protected InvocationContext() { this.hooks = null; this.promptExecutionSettings = null; this.toolCallBehavior = null; + this.functionChoiceBehavior = null; this.contextVariableTypes = new ContextVariableTypes(); this.invocationReturnMode = InvocationReturnMode.NEW_MESSAGES_ONLY; this.telemetry = null; @@ -76,6 +83,7 @@ protected InvocationContext(@Nullable InvocationContext context) { this.hooks = null; this.promptExecutionSettings = null; this.toolCallBehavior = null; + this.functionChoiceBehavior = null; this.contextVariableTypes = new ContextVariableTypes(); this.invocationReturnMode = InvocationReturnMode.NEW_MESSAGES_ONLY; this.telemetry = null; @@ -83,6 +91,7 @@ protected InvocationContext(@Nullable InvocationContext context) { this.hooks = context.hooks; this.promptExecutionSettings = context.promptExecutionSettings; this.toolCallBehavior = context.toolCallBehavior; + this.functionChoiceBehavior = context.functionChoiceBehavior; this.contextVariableTypes = context.contextVariableTypes; this.invocationReturnMode = context.invocationReturnMode; this.telemetry = context.telemetry; @@ -156,6 +165,16 @@ public ToolCallBehavior getToolCallBehavior() { return toolCallBehavior; } + /** + * Get the behavior for function choice. + * + * @return The behavior for function choice. + */ + @Nullable + public FunctionChoiceBehavior getFunctionChoiceBehavior() { + return functionChoiceBehavior; + } + /** * Get the types of context variables. * @@ -190,6 +209,8 @@ public static class Builder implements SemanticKernelBuilder private PromptExecutionSettings promptExecutionSettings; @Nullable private ToolCallBehavior toolCallBehavior; + @Nullable + private FunctionChoiceBehavior functionChoiceBehavior; private InvocationReturnMode invocationReturnMode = InvocationReturnMode.NEW_MESSAGES_ONLY; @Nullable private SemanticKernelTelemetry telemetry; @@ -226,10 +247,30 @@ public Builder withPromptExecutionSettings( */ public Builder withToolCallBehavior( @Nullable ToolCallBehavior toolCallBehavior) { + if (toolCallBehavior != null && functionChoiceBehavior != null) { + throw new SKException( + "ToolCallBehavior cannot be set when FunctionChoiceBehavior is set."); + } this.toolCallBehavior = toolCallBehavior; return this; } + /** + * Add function choice behavior to the builder. + * + * @param functionChoiceBehavior the behavior to add. + * @return this {@link Builder} + */ + public Builder withFunctionChoiceBehavior( + @Nullable FunctionChoiceBehavior functionChoiceBehavior) { + if (functionChoiceBehavior != null && toolCallBehavior != null) { + throw new SKException( + "FunctionChoiceBehavior cannot be set when ToolCallBehavior is set."); + } + this.functionChoiceBehavior = functionChoiceBehavior; + return this; + } + /** * Add a context variable type converter to the builder. * @@ -269,7 +310,7 @@ public Builder withReturnMode(InvocationReturnMode invocationReturnMode) { /** * Add a tracer to the builder. * - * @param tracer the tracer to add. + * @param telemetry the tracer to add. * @return this {@link Builder} */ public Builder withTelemetry(@Nullable SemanticKernelTelemetry telemetry) { @@ -283,6 +324,7 @@ public InvocationContext build() { telemetry = new SemanticKernelTelemetry(); } return new InvocationContext(hooks, promptExecutionSettings, toolCallBehavior, + functionChoiceBehavior, contextVariableTypes, invocationReturnMode, telemetry); } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelArguments.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelArguments.java index e4fb8319..bf9e6565 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelArguments.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelArguments.java @@ -42,8 +42,8 @@ public class KernelArguments implements Map> { * @param variables The variables to use for the function invocation. */ protected KernelArguments( - @Nullable Map> variables, - @Nullable Map executionSettings) { + @Nullable Map> variables, + @Nullable Map executionSettings) { if (variables == null) { this.variables = new CaseInsensitiveMap<>(); } else { @@ -112,14 +112,14 @@ public ContextVariable getInput() { */ public String prettyPrint() { return variables.entrySet().stream() - .reduce( - "", - (str, entry) -> str - + System.lineSeparator() - + entry.getKey() - + ": " - + entry.getValue().toPromptString(ContextVariableTypes.getGlobalTypes()), - (a, b) -> a + b); + .reduce( + "", + (str, entry) -> str + + System.lineSeparator() + + entry.getKey() + + ": " + + entry.getValue().toPromptString(ContextVariableTypes.getGlobalTypes()), + (a, b) -> a + b); } /** @@ -149,9 +149,9 @@ ContextVariable get(String key, Class clazz) { } throw new SKException( - String.format( - "Variable %s is of type %s, but requested type is %s", - key, value.getType().getClazz(), clazz)); + String.format( + "Variable %s is of type %s, but requested type is %s", + key, value.getType().getClazz(), clazz)); } /** @@ -243,7 +243,6 @@ public static Builder builder() { return new Builder<>(KernelArguments::new); } - /** * Builder for ContextVariables */ @@ -253,7 +252,6 @@ public static class Builder implements SemanticKernel private final Map> variables; private final Map executionSettings; - protected Builder(Function constructor) { this.constructor = constructor; this.variables = new HashMap<>(); @@ -293,10 +291,10 @@ public Builder withInput(Object content) { */ public Builder withInput(T content, ContextVariableTypeConverter typeConverter) { return withInput(new ContextVariable<>( - new ContextVariableType<>( - typeConverter, - typeConverter.getType()), - content)); + new ContextVariableType<>( + typeConverter, + typeConverter.getType()), + content)); } /** @@ -352,12 +350,12 @@ public Builder withVariable(String key, Object value) { * @throws SKException if the value cannot be converted to a ContextVariable */ public Builder withVariable(String key, T value, - ContextVariableTypeConverter typeConverter) { + ContextVariableTypeConverter typeConverter) { return withVariable(key, new ContextVariable<>( - new ContextVariableType<>( - typeConverter, - typeConverter.getType()), - value)); + new ContextVariableType<>( + typeConverter, + typeConverter.getType()), + value)); } /** @@ -376,8 +374,14 @@ public Builder withExecutionSettings(PromptExecutionSettings executionSetting * @param executionSettings Execution settings * @return {$code this} Builder for fluent coding */ - public Builder withExecutionSettings(Map executionSettings) { - return withExecutionSettings(new ArrayList<>(executionSettings.values())); + public Builder withExecutionSettings( + Map executionSettings) { + if (executionSettings == null) { + return this; + } + + this.executionSettings.putAll(executionSettings); + return this; } /** @@ -397,17 +401,15 @@ public Builder withExecutionSettings(List executionS if (this.executionSettings.containsKey(serviceId)) { if (serviceId.equals(PromptExecutionSettings.DEFAULT_SERVICE_ID)) { throw new SKException( - String.format( - "Multiple prompt execution settings with the default service id '%s' or no service id have been provided. Specify a single default prompt execution settings and provide a unique service id for all other instances.", - PromptExecutionSettings.DEFAULT_SERVICE_ID) - ); + String.format( + "Multiple prompt execution settings with the default service id '%s' or no service id have been provided. Specify a single default prompt execution settings and provide a unique service id for all other instances.", + PromptExecutionSettings.DEFAULT_SERVICE_ID)); } throw new SKException( - String.format( - "Multiple prompt execution settings with the service id '%s' have been provided. Specify a unique service id for all instances.", - serviceId) - ); + String.format( + "Multiple prompt execution settings with the service id '%s' have been provided. Specify a unique service id for all instances.", + serviceId)); } this.executionSettings.put(serviceId, settings); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java index 2ea90974..642b5add 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java @@ -135,14 +135,15 @@ private Flux> invokeInternalAsync( .executeHooks(new FunctionInvokingEvent(this, args)); args = KernelArguments.builder() - .withVariables(invokingEvent.getArguments()) - .withExecutionSettings(this.getExecutionSettings()) - .build(); + .withVariables(invokingEvent.getArguments()) + .withExecutionSettings(this.getExecutionSettings()) + .build(); AIServiceSelection aiServiceSelection = kernel .getServiceSelector() .trySelectAIService( TextAIService.class, + this, args); AIService client = aiServiceSelection != null ? aiServiceSelection.getService() diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/BaseAIServiceSelector.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/BaseAIServiceSelector.java index 03ecdd78..b022bf35 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/BaseAIServiceSelector.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/BaseAIServiceSelector.java @@ -38,8 +38,8 @@ public AIServiceSelection trySelectAIService( @Override @Nullable public AIServiceSelection trySelectAIService( - Class serviceType, - @Nullable KernelArguments arguments) { + Class serviceType, + @Nullable KernelArguments arguments) { return trySelectAIService(serviceType, arguments, services); } @@ -64,7 +64,6 @@ protected abstract AIServiceSelection trySelectAIServic @Nullable KernelArguments arguments, Map, AIService> services); - /** * Resolves an {@link AIService} from the {@code services} argument using the specified * {@code function} and {@code arguments} for selection. @@ -79,9 +78,9 @@ protected abstract AIServiceSelection trySelectAIServic */ @Nullable protected AIServiceSelection trySelectAIService( - Class serviceType, - @Nullable KernelArguments arguments, - Map, AIService> services) { + Class serviceType, + @Nullable KernelArguments arguments, + Map, AIService> services) { return trySelectAIService(serviceType, null, arguments, services); } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/OrderedAIServiceSelector.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/OrderedAIServiceSelector.java index 6b04fe2b..30828233 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/OrderedAIServiceSelector.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/OrderedAIServiceSelector.java @@ -66,16 +66,16 @@ public AIServiceSelection trySelectAIService( Map, AIService> services) { if (function == null) { - return selectAIService(serviceType, arguments != null ? arguments.getExecutionSettings() : null); + return selectAIService(serviceType, + arguments != null ? arguments.getExecutionSettings() : null); } return selectAIService(serviceType, function.getExecutionSettings()); } - private AIServiceSelection selectAIService( - Class serviceType, - @Nullable Map executionSettings) { + Class serviceType, + @Nullable Map executionSettings) { if (executionSettings == null || executionSettings.isEmpty()) { AIService service = getAnyService(serviceType); @@ -84,50 +84,50 @@ private AIServiceSelection selectAIService( } } else { AIServiceSelection selection = executionSettings - .entrySet() - .stream() - .map(keyValue -> { - - PromptExecutionSettings settings = keyValue.getValue(); - String serviceId = keyValue.getKey(); - - if (!Verify.isNullOrEmpty(serviceId)) { - AIService service = getService(serviceId); - if (service != null) { - return castServiceSelection( - new AIServiceSelection<>(service, settings)); - } + .entrySet() + .stream() + .map(keyValue -> { + + PromptExecutionSettings settings = keyValue.getValue(); + String serviceId = keyValue.getKey(); + + if (!Verify.isNullOrEmpty(serviceId)) { + AIService service = getService(serviceId); + if (service != null) { + return castServiceSelection( + new AIServiceSelection<>(service, settings)); } + } - return null; - }) - .filter(Objects::nonNull) - .findFirst() - .orElseGet(() -> null); + return null; + }) + .filter(Objects::nonNull) + .findFirst() + .orElseGet(() -> null); if (selection != null) { return castServiceSelection(selection); } selection = executionSettings - .entrySet() - .stream() - .map(keyValue -> { - PromptExecutionSettings settings = keyValue.getValue(); - - if (!Verify.isNullOrEmpty(settings.getModelId())) { - AIService service = getServiceByModelId(settings.getModelId()); - if (service != null) { - return castServiceSelection( - new AIServiceSelection<>(service, settings)); - } + .entrySet() + .stream() + .map(keyValue -> { + PromptExecutionSettings settings = keyValue.getValue(); + + if (!Verify.isNullOrEmpty(settings.getModelId())) { + AIService service = getServiceByModelId(settings.getModelId()); + if (service != null) { + return castServiceSelection( + new AIServiceSelection<>(service, settings)); } + } - return null; - }) - .filter(Objects::nonNull) - .findFirst() - .orElseGet(() -> null); + return null; + }) + .filter(Objects::nonNull) + .findFirst() + .orElseGet(() -> null); if (selection != null) { return castServiceSelection(selection); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/TextAIService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/TextAIService.java index 09b3ea1f..3eee32d3 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/TextAIService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/TextAIService.java @@ -29,5 +29,5 @@ public interface TextAIService extends AIService { * future and/or made configurable should need arise. *

*/ - int MAXIMUM_INFLIGHT_AUTO_INVOKES = 5; + int MAXIMUM_INFLIGHT_AUTO_INVOKES = 128; } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatMessageContent.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatMessageContent.java index 0408860e..e06472d1 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatMessageContent.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatMessageContent.java @@ -60,17 +60,17 @@ public ChatMessageContent( * @param content the content */ public ChatMessageContent( - AuthorRole authorRole, - String authorName, - String content) { + AuthorRole authorRole, + String authorName, + String content) { this( - authorRole, - authorName, - content, - null, - null, - null, - null); + authorRole, + authorName, + content, + null, + null, + null, + null); } /**