Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

// Uncomment this to enable JSON checkpointing to the local file system.
#define CHECKPOINT_JSON
//#define CHECKPOINT_JSON

using System.Diagnostics;
using System.Reflection;
using System.Text.Json;
using Azure.AI.Agents.Persistent;
using Azure.Identity;
using Microsoft.Agents.AI.Workflows;
#if CHECKPOINT_JSON
using Microsoft.Agents.AI.Workflows.Checkpointing;
#endif
using Microsoft.Agents.AI.Workflows.Declarative;
using Microsoft.Agents.AI.Workflows.Declarative.Events;
using Microsoft.Agents.AI.Workflows.Declarative.Kit;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;

Expand Down Expand Up @@ -63,20 +67,21 @@ private async Task ExecuteAsync()

#if CHECKPOINT_JSON
// Use a file-system based JSON checkpoint store to persist checkpoints to disk.
DirectoryInfo checkpointFolder = Directory.CreateDirectory(Path.Combine(".", $"chk-{DateTime.Now:YYmmdd-hhMMss-ff}"));
DirectoryInfo checkpointFolder = Directory.CreateDirectory(Path.Combine(".", $"chk-{DateTime.Now:yyMMdd-hhmmss-ff}"));
CheckpointManager checkpointManager = CheckpointManager.CreateJson(new FileSystemJsonCheckpointStore(checkpointFolder));
Checkpointed<StreamingRun> run = await InProcessExecution.StreamAsync(workflow, input, checkpointManager);
#else
// Use an in-memory checkpoint store that will not persist checkpoints beyond the lifetime of the process.
CheckpointManager checkpointManager = CheckpointManager.CreateInMemory();
#endif

Checkpointed<StreamingRun> run = await InProcessExecution.StreamAsync(workflow, input, checkpointManager);

bool isComplete = false;
InputResponse? response = null;
object? response = null;
do
{
ExternalRequest? inputRequest = await this.MonitorAndDisposeWorkflowRunAsync(run, response);
if (inputRequest is not null)
ExternalRequest? externalRequest = await this.MonitorAndDisposeWorkflowRunAsync(run, response);
if (externalRequest is not null)
{
Notify("\nWORKFLOW: Yield");

Expand All @@ -86,7 +91,7 @@ private async Task ExecuteAsync()
}

// Process the external request.
response = HandleExternalRequest(inputRequest);
response = await this.HandleExternalRequestAsync(externalRequest);

// Let's resume on an entirely new workflow instance to demonstrate checkpoint portability.
workflow = this.CreateWorkflow();
Expand All @@ -107,11 +112,25 @@ private async Task ExecuteAsync()
Notify("\nWORKFLOW: Done!\n");
}

/// <summary>
/// Create the workflow from the declarative YAML. Includes definition of the
/// <see cref="DeclarativeWorkflowOptions" /> and the associated <see cref="WorkflowAgentProvider"/>.
/// </summary>
/// <remarks>
/// The value assigned to <see cref="IncludeFunctions" /> controls on whether the function
/// tools (<see cref="AIFunction"/>) initialized in the constructor are included for auto-invocation.
/// </remarks>
private Workflow CreateWorkflow()
{
// Use DeclarativeWorkflowBuilder to build a workflow based on a YAML file.
AzureAgentProvider agentProvider = new(this.FoundryEndpoint, new AzureCliCredential())
{
// Functions included here will be auto-executed by the framework.
Functions = IncludeFunctions ? this.FunctionMap.Values : null,
};

DeclarativeWorkflowOptions options =
new(new AzureAgentProvider(this.FoundryEndpoint, new AzureCliCredential()))
new(agentProvider)
{
Configuration = this.Configuration,
//ConversationId = null, // Assign to continue a conversation
Expand All @@ -121,8 +140,18 @@ private Workflow CreateWorkflow()
return DeclarativeWorkflowBuilder.Build<string>(this.WorkflowFile, options);
}

/// <summary>
/// Configuration key used to identify the Foundry project endpoint.
/// </summary>
private const string ConfigKeyFoundryEndpoint = "FOUNDRY_PROJECT_ENDPOINT";

/// <summary>
/// Controls on whether the function tools (<see cref="AIFunction"/>) initialized
/// in the constructor are included for auto-invocation.
/// NOTE: By default, no functions exist as part of this sample.
/// </summary>
private const bool IncludeFunctions = true;

private static Dictionary<string, string> NameCache { get; } = [];
private static HashSet<string> FileCache { get; } = [];

Expand All @@ -132,6 +161,7 @@ private Workflow CreateWorkflow()
private PersistentAgentsClient FoundryClient { get; }
private IConfiguration Configuration { get; }
private CheckpointInfo? LastCheckpoint { get; set; }
private Dictionary<string, AIFunction> FunctionMap { get; }

private Program(string workflowFile, string? workflowInput)
{
Expand All @@ -142,12 +172,21 @@ private Program(string workflowFile, string? workflowInput)

this.FoundryEndpoint = this.Configuration[ConfigKeyFoundryEndpoint] ?? throw new InvalidOperationException($"Undefined configuration setting: {ConfigKeyFoundryEndpoint}");
this.FoundryClient = new PersistentAgentsClient(this.FoundryEndpoint, new AzureCliCredential());

List<AIFunction> functions =
[
// Manually define any custom functions that may be required by agents within the workflow.
// By default, this sample does not include any functions.
//AIFunctionFactory.Create(),
];
this.FunctionMap = functions.ToDictionary(f => f.Name);
}

private async Task<ExternalRequest?> MonitorAndDisposeWorkflowRunAsync(Checkpointed<StreamingRun> run, InputResponse? response = null)
private async Task<ExternalRequest?> MonitorAndDisposeWorkflowRunAsync(Checkpointed<StreamingRun> run, object? response = null)
{
await using IAsyncDisposable disposeRun = run;

bool hasStreamed = false;
string? messageId = null;

await foreach (WorkflowEvent workflowEvent in run.Run.WatchStreamAsync().ConfigureAwait(false))
Expand Down Expand Up @@ -211,11 +250,12 @@ private Program(string workflowFile, string? workflowInput)
case AgentRunUpdateEvent streamEvent:
if (!string.Equals(messageId, streamEvent.Update.MessageId, StringComparison.Ordinal))
{
hasStreamed = false;
messageId = streamEvent.Update.MessageId;

if (messageId is not null)
{
string? agentId = streamEvent.Update.AuthorName;
string? agentId = streamEvent.Update.AgentId;
if (agentId is not null)
{
if (!NameCache.TryGetValue(agentId, out string? realName))
Expand Down Expand Up @@ -245,11 +285,18 @@ private Program(string workflowFile, string? workflowInput)
await DownloadFileContentAsync(Path.GetFileName(messageUpdate.TextAnnotation?.TextToReplace ?? "response.png"), content);
}
break;
case RequiredActionUpdate actionUpdate:
Console.ForegroundColor = ConsoleColor.White;
Console.Write($"Calling tool: {actionUpdate.FunctionName}");
Console.ForegroundColor = ConsoleColor.DarkGray;
Console.WriteLine($" [{actionUpdate.ToolCallId}]");
break;
}
try
{
Console.ResetColor();
Console.Write(streamEvent.Data);
Console.Write(streamEvent.Update.Text);
hasStreamed |= !string.IsNullOrEmpty(streamEvent.Update.Text);
}
finally
{
Expand All @@ -260,7 +307,11 @@ private Program(string workflowFile, string? workflowInput)
case AgentRunResponseEvent messageEvent:
try
{
Console.WriteLine();
if (hasStreamed)
{
Console.WriteLine();
}

if (messageEvent.Response.Usage is not null)
{
Console.ForegroundColor = ConsoleColor.DarkGray;
Expand All @@ -277,14 +328,31 @@ private Program(string workflowFile, string? workflowInput)

return default;
}
private static InputResponse HandleExternalRequest(ExternalRequest request)

/// <summary>
/// Handle request for external input, either from a human or a function tool invocation.
/// </summary>
private async ValueTask<object> HandleExternalRequestAsync(ExternalRequest request) =>
request.Data.TypeId.TypeName switch
{
// Request for human input
_ when request.Data.TypeId.IsMatch<InputRequest>() => HandleInputRequest(request.DataAs<InputRequest>()!),
// Request for function tool invocation. (Only active when functions are defined and IncludeFunctions is true.)
_ when request.Data.TypeId.IsMatch<AgentToolRequest>() => await this.HandleToolRequestAsync(request.DataAs<AgentToolRequest>()!),
// Unknown request type.
_ => throw new InvalidOperationException($"Unsupported external request type: {request.GetType().Name}."),
};

/// <summary>
/// Handle request for human input.
/// </summary>
private static InputResponse HandleInputRequest(InputRequest request)
{
InputRequest? message = request.Data.As<InputRequest>();
string? userInput;
do
{
Console.ForegroundColor = ConsoleColor.DarkGreen;
Console.Write($"\n{message?.Prompt ?? "INPUT:"} ");
Console.Write($"\n{request.Prompt ?? "INPUT:"} ");
Console.ForegroundColor = ConsoleColor.White;
userInput = Console.ReadLine();
}
Expand All @@ -293,6 +361,30 @@ private static InputResponse HandleExternalRequest(ExternalRequest request)
return new InputResponse(userInput);
}

/// <summary>
/// Handle a function tool request by invoking the specified tools and returning the results.
/// </summary>
/// <remarks>
/// This handler is only active when <see cref="IncludeFunctions"/> is set to true and
/// one or more <see cref="AIFunction"/> instances are defined in the constructor.
/// </remarks>
private async ValueTask<AgentToolResponse> HandleToolRequestAsync(AgentToolRequest request)
{
Task<FunctionResultContent>[] functionTasks = request.FunctionCalls.Select(functionCall => InvokesToolAsync(functionCall)).ToArray();

await Task.WhenAll(functionTasks);

return AgentToolResponse.Create(request, functionTasks.Select(task => task.Result));

async Task<FunctionResultContent> InvokesToolAsync(FunctionCallContent functionCall)
{
AIFunction functionTool = this.FunctionMap[functionCall.Name];
AIFunctionArguments? functionArguments = functionCall.Arguments is null ? null : new(functionCall.Arguments.NormalizePortableValues());
object? result = await functionTool.InvokeAsync(functionArguments);
return new FunctionResultContent(functionCall.CallId, JsonSerializer.Serialize(result));
}
}

private static string? ParseWorkflowFile(string[] args)
{
string? workflowFile = args.FirstOrDefault();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,41 @@ IEnumerable<MessageInputContentBlock> GetContent()
}

/// <inheritdoc/>
public override async Task<AIAgent> GetAgentAsync(string agentId, CancellationToken cancellationToken = default) =>
await this.GetAgentsClient().GetAIAgentAsync(agentId, chatOptions: null, clientFactory: null, cancellationToken).ConfigureAwait(false);
public override async Task<AIAgent> GetAgentAsync(string agentId, CancellationToken cancellationToken = default)
{
ChatClientAgent agent =
await this.GetAgentsClient().GetAIAgentAsync(
agentId,
new ChatOptions()
{
AllowMultipleToolCalls = this.AllowMultipleToolCalls,
},
clientFactory: null,
cancellationToken).ConfigureAwait(false);

FunctionInvokingChatClient? functionInvokingClient = agent.GetService<FunctionInvokingChatClient>();
if (functionInvokingClient is not null)
{
// Allow concurrent invocations if configured
functionInvokingClient.AllowConcurrentInvocation = this.AllowConcurrentInvocation;
// Allows the caller to respond with function responses
functionInvokingClient.TerminateOnUnknownCalls = true;
// Make functions available for execution. Doesn't change what tool is available for any given agent.
if (this.Functions is not null)
{
if (functionInvokingClient.AdditionalTools is null)
{
functionInvokingClient.AdditionalTools = [.. this.Functions];
}
else
{
functionInvokingClient.AdditionalTools = [.. functionInvokingClient.AdditionalTools, .. this.Functions];
}
}
}

return agent;
}

/// <inheritdoc/>
public override async Task<ChatMessage> GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.Declarative.Events;

/// <summary>
/// Represents a request for user input.
/// </summary>
public sealed class AgentToolRequest
{
/// <summary>
/// The name of the agent associated with the tool request.
/// </summary>
public string AgentName { get; }

/// <summary>
/// A list of tool requests.
/// </summary>
public IList<FunctionCallContent> FunctionCalls { get; }

[JsonConstructor]
internal AgentToolRequest(string agentName, IList<FunctionCallContent>? functionCalls = null)
{
this.AgentName = agentName;
this.FunctionCalls = functionCalls ?? [];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI.Workflows.Declarative.Events;

/// <summary>
/// Represents a user input response.
/// </summary>
public sealed class AgentToolResponse
{
/// <summary>
/// The name of the agent associated with the tool response.
/// </summary>
public string AgentName { get; }

/// <summary>
/// A list of tool responses.
/// </summary>
public IList<FunctionResultContent> FunctionResults { get; }

/// <summary>
/// Initializes a new instance of the <see cref="InputResponse"/> class.
/// </summary>
[JsonConstructor]
internal AgentToolResponse(string agentName, IList<FunctionResultContent> functionResults)
{
this.AgentName = agentName;
this.FunctionResults = functionResults;
}

/// <summary>
/// Factory method to create an <see cref="AgentToolResponse"/> from an <see cref="AgentToolRequest"/>
/// Ensures that all function calls in the request have a corresponding result.
/// </summary>
/// <param name="toolRequest">The tool request.</param>
/// <param name="functionResults">On or more function results</param>
/// <returns>An <see cref="AgentToolResponse"/> that can be provided to the workflow.</returns>
/// <exception cref="DeclarativeActionException">Not all <see cref="AgentToolRequest.FunctionCalls"/> have a corresponding <see cref="FunctionResultContent"/>.</exception>
public static AgentToolResponse Create(AgentToolRequest toolRequest, params IEnumerable<FunctionResultContent> functionResults)
{
HashSet<string> callIds = [.. toolRequest.FunctionCalls.Select(call => call.CallId)];
HashSet<string> resultIds = [.. functionResults.Select(call => call.CallId)];
if (!callIds.SetEquals(resultIds))
{
throw new DeclarativeActionException($"Missing results for: {string.Join(",", callIds.Except(resultIds))}");
}
return new AgentToolResponse(toolRequest.AgentName, [.. functionResults]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ public static ChatMessage ToChatMessage(this RecordDataValue message) =>

public static ChatMessage ToChatMessage(this StringDataValue message) => new(ChatRole.User, message.Value);

public static ChatMessage ToChatMessage(this IEnumerable<FunctionResultContent> functionResults) =>
new(ChatRole.Tool, [.. functionResults]);

public static AdditionalPropertiesDictionary? ToMetadata(this RecordDataValue? metadata)
{
if (metadata is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ public static RecordDataValue ToRecordValue(this IDictionary value)

IEnumerable<KeyValuePair<string, DataValue>> GetFields()
{
foreach (string key in value.Keys)
foreach (DictionaryEntry entry in value)
{
yield return new KeyValuePair<string, DataValue>(key, value[key].ToDataValue());
yield return new KeyValuePair<string, DataValue>((string)entry.Key, entry.Value.ToDataValue());
}
}
}
Expand Down
Loading
Loading