diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentBinding.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentBinding.cs index 4897189d90..84f04bd8dc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentBinding.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentBinding.cs @@ -6,16 +6,27 @@ namespace Microsoft.Agents.AI.Workflows; /// -/// Represents the workflow binding details for an AI agent, including configuration options for event emission. +/// Represents the workflow binding details for an AI agent, including configuration options for agent hosting behaviour. /// /// The AI agent. -/// Specifies whether the agent should emit events. If null, the default behavior is applied. -public record AIAgentBinding(AIAgent Agent, bool EmitEvents = false) +/// The options for configuring the AI agent host. +/// +public record AIAgentBinding(AIAgent Agent, AIAgentHostOptions? Options = null) : ExecutorBinding(Throw.IfNull(Agent).GetDescriptiveId(), - (_) => new(new AIAgentHostExecutor(Agent, EmitEvents)), + (_) => new(new AIAgentHostExecutor(Agent, Options ?? new())), typeof(AIAgentHostExecutor), Agent) { + /// + /// Initializes a new instance of the AIAgentBinding class, associating it with the specified AI agent and + /// optionally enabling event emission. + /// + /// The AI agent. + /// Specifies whether the agent should emit events. If null, the default behavior is applied. + public AIAgentBinding(AIAgent agent, bool emitEvents = false) + : this(agent, new AIAgentHostOptions { EmitAgentRunUpdateEvents = emitEvents }) + { } + /// public override bool IsSharedInstance => false; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentHostOptions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentHostOptions.cs new file mode 100644 index 0000000000..5c8db7698c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentHostOptions.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows; + +/// +/// . +/// +public sealed class AIAgentHostOptions +{ + /// + /// Gets or sets a value indicating whether agent streaming update events should be emitted during execution. + /// If , the value will be taken from the /> + /// + public bool? EmitAgentRunUpdateEvents { get; set; } + + /// + /// Gets or sets a value indicating whether aggregated agent response events should be emitted during execution. + /// + public bool EmitAgentRunResponseEvents { get; set; } + + /// + /// Gets or sets a value indicating whether should be intercepted and sent + /// as a message to the workflow for handling, instead of being raised as a request. + /// + public bool InterceptUserInputRequests { get; set; } + + /// + /// Gets or sets a value indicating whether without a corresponding + /// should be intercepted and sent as a message to the workflow for handling, + /// instead of being raised as a request. + /// + public bool InterceptUnterminatedFunctionCalls { get; set; } + + /// + /// Gets or sets a value indicating whether other messages from other agents should be assigned to the + /// role during execution. + /// + public bool ReassignOtherAgentsAsUsers { get; set; } = true; + + /// + /// Gets or sets a value indicating whether incoming messages are automatically forwarded before new messages generated + /// by the agent during its turn. + /// + public bool ForwardIncomingMessages { get; set; } = true; +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs index 9f9906270e..4a363f908e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs @@ -19,6 +19,29 @@ public static ChatMessage ToChatMessage(this AgentRunResponseUpdate update) => RawRepresentation = update.RawRepresentation ?? update, }; + public static ChatMessage ChatAssistantToUserIfNotFromNamed(this ChatMessage message, string agentName) + => message.ChatAssistantToUserIfNotFromNamed(agentName, out _, false); + + private static ChatMessage ChatAssistantToUserIfNotFromNamed(this ChatMessage message, string agentName, out bool changed, bool inplace = true) + { + changed = false; + + if (message.Role == ChatRole.Assistant && + message.AuthorName != agentName && + message.Contents.All(c => c is TextContent or DataContent or UriContent or UsageContent)) + { + if (!inplace) + { + message = message.Clone(); + } + + message.Role = ChatRole.User; + changed = true; + } + + return message; + } + /// /// Iterates through looking for messages and swapping /// any that have a different from to @@ -29,11 +52,9 @@ public static ChatMessage ToChatMessage(this AgentRunResponseUpdate update) => List? roleChanged = null; foreach (var m in messages) { - if (m.Role == ChatRole.Assistant && - m.AuthorName != targetAgentName && - m.Contents.All(c => c is TextContent or DataContent or UriContent or UsageContent)) + m.ChatAssistantToUserIfNotFromNamed(targetAgentName, out bool changed); + if (changed) { - m.Role = ChatRole.User; (roleChanged ??= []).Add(m); } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentWorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentWorkflowBuilder.cs index c5272e39ea..e2046d41e5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/AgentWorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AgentWorkflowBuilder.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Specialized; @@ -35,38 +34,28 @@ public static Workflow BuildSequential(string workflowName, params IEnumerable agents) { - Throw.IfNull(agents); + Throw.IfNullOrEmpty(agents); // Create a builder that chains the agents together in sequence. The workflow simply begins // with the first agent in the sequence. - WorkflowBuilder? builder = null; - ExecutorBinding? previous = null; - foreach (var agent in agents) + + AIAgentHostOptions options = new() { - AgentRunStreamingExecutor agentExecutor = new(agent, includeInputInOutput: true); - - if (builder is null) - { - builder = new WorkflowBuilder(agentExecutor); - } - else - { - Debug.Assert(previous is not null); - builder.AddEdge(previous, agentExecutor); - } - - previous = agentExecutor; - } + ReassignOtherAgentsAsUsers = true, + ForwardIncomingMessages = true, + }; + + List agentExecutors = agents.Select(agent => agent.BindAsExecutor(options)).ToList(); - if (previous is null) + ExecutorBinding previous = agentExecutors[0]; + WorkflowBuilder builder = new(previous); + + foreach (ExecutorBinding next in agentExecutors.Skip(1)) { - Throw.ArgumentException(nameof(agents), "At least one agent must be provided to build a sequential workflow."); + builder.AddEdge(previous, next); + previous = next; } - // Add an ending executor that batches up all messages from the last agent - // so that it's published as a single list result. - Debug.Assert(builder is not null); - OutputMessagesExecutor end = new(); builder = builder.AddEdge(previous, end).WithOutputFrom(end); if (workflowName is not null) @@ -125,9 +114,12 @@ private static Workflow BuildConcurrentCore( // so that the final accumulator receives a single list of messages from each agent. Otherwise, the // accumulator would not be able to determine what came from what agent, as there's currently no // provenance tracking exposed in the workflow context passed to a handler. - ExecutorBinding[] agentExecutors = (from agent in agents select (ExecutorBinding)new AgentRunStreamingExecutor(agent, includeInputInOutput: false)).ToArray(); - ExecutorBinding[] accumulators = [.. from agent in agentExecutors select (ExecutorBinding)new CollectChatMessagesExecutor($"Batcher/{agent.Id}")]; + + ExecutorBinding[] agentExecutors = (from agent in agents + select agent.BindAsExecutor(new AIAgentHostOptions() { ReassignOtherAgentsAsUsers = true })).ToArray(); + ExecutorBinding[] accumulators = [.. from agent in agentExecutors select (ExecutorBinding)new AggregateTurnMessagesExecutor($"Batcher/{agent.Id}")]; builder.AddFanOutEdge(start, agentExecutors); + for (int i = 0; i < agentExecutors.Length; i++) { builder.AddEdge(agentExecutors[i], accumulators[i]); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs index 238734b598..c1f4566cc0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ChatProtocolExecutor.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -18,6 +19,12 @@ public class ChatProtocolExecutorOptions /// If set, the executor will accept string messages and convert them to chat messages with this role. /// public ChatRole? StringMessageChatRole { get; set; } + + /// + /// Gets or sets a value indicating whether the executor should automatically send the + /// after returning from + /// + public bool AutoSendTurnToken { get; set; } = true; } /// @@ -26,8 +33,8 @@ public class ChatProtocolExecutorOptions /// public abstract class ChatProtocolExecutor : StatefulExecutor> { - private static readonly Func> s_initFunction = () => []; - private readonly ChatRole? _stringMessageChatRole; + internal static readonly Func> s_initFunction = () => []; + private readonly ChatProtocolExecutorOptions _options; /// /// Initializes a new instance of the class. @@ -38,16 +45,28 @@ public abstract class ChatProtocolExecutor : StatefulExecutor> protected ChatProtocolExecutor(string id, ChatProtocolExecutorOptions? options = null, bool declareCrossRunShareable = false) : base(id, () => [], declareCrossRunShareable: declareCrossRunShareable) { - this._stringMessageChatRole = options?.StringMessageChatRole; + this._options = options ?? new(); } + /// + /// Gets a value indicating whether string-based messages are by this . + /// + [MemberNotNullWhen(true, nameof(StringMessageChatRole))] + protected bool SupportsStringMessage => this.StringMessageChatRole.HasValue; + + /// + protected ChatRole? StringMessageChatRole => this._options.StringMessageChatRole; + + /// + protected bool AutoSendTurnToken => this._options.AutoSendTurnToken; + /// protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) { - if (this._stringMessageChatRole.HasValue) + if (this.SupportsStringMessage) { routeBuilder = routeBuilder.AddHandler( - (message, context) => this.AddMessageAsync(new(this._stringMessageChatRole.Value, message), context)); + (message, context) => this.AddMessageAsync(new(this.StringMessageChatRole.Value, message), context)); } return routeBuilder.AddHandler(this.AddMessageAsync) @@ -111,7 +130,10 @@ public ValueTask TakeTurnAsync(TurnToken token, IWorkflowContext context, Cancel await this.TakeTurnAsync(maybePendingMessages ?? s_initFunction(), context, token.EmitEvents, cancellationToken) .ConfigureAwait(false); - await context.SendMessageAsync(token, cancellationToken: cancellationToken).ConfigureAwait(false); + if (this.AutoSendTurnToken) + { + await context.SendMessageAsync(token, cancellationToken: cancellationToken).ConfigureAwait(false); + } // Rerun the initialStateFactory to reset the state to empty list. (We could return the empty list directly, // but this is more consistent if the initial state factory becomes more complex.) @@ -119,6 +141,28 @@ await this.TakeTurnAsync(maybePendingMessages ?? s_initFunction(), context, toke } } + /// + /// Processes the current set of turn messages using the specified asynchronous processing function. + /// + /// If the provided list of chat messages is null, an initial empty list is supplied to the + /// processing function. If the processing function returns null, an empty list is used as the result. + /// A delegate that asynchronously processes a list of chat messages within the given workflow context and + /// cancellation token, returning the processed list of chat messages or null. + /// The workflow context in which the messages are processed. + /// A token that can be used to cancel the asynchronous operation. + /// A ValueTask that represents the asynchronous operation. The result contains the processed list of chat messages, + /// or an empty list if the processing function returns null. + protected ValueTask ProcessTurnMessagesAsync(Func, IWorkflowContext, CancellationToken, ValueTask?>> processFunc, IWorkflowContext context, CancellationToken cancellationToken) + { + return this.InvokeWithStateAsync(InvokeProcessFuncAsync, context, cancellationToken: cancellationToken); + + async ValueTask?> InvokeProcessFuncAsync(List? maybePendingMessages, IWorkflowContext context, CancellationToken cancellationToken) + { + return (await processFunc(maybePendingMessages ?? s_initFunction(), context, cancellationToken).ConfigureAwait(false)) + ?? s_initFunction(); + } + } + /// /// When overridden in a derived class, processes the accumulated chat messages for a single turn. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs index 952f9c4748..3cc0e6e6a1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/EdgeMap.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -12,7 +13,7 @@ internal sealed class EdgeMap { private readonly Dictionary _edgeRunners = []; private readonly Dictionary _statefulRunners = []; - private readonly Dictionary _portEdgeRunners; + private readonly ConcurrentDictionary _portEdgeRunners; private readonly ResponseEdgeRunner _inputRunner; private readonly IStepTracer? _stepTracer; @@ -51,12 +52,16 @@ public EdgeMap(IRunnerContext runContext, } } - this._portEdgeRunners = workflowPorts.ToDictionary( - port => port.Id, - port => ResponseEdgeRunner.ForPort(runContext, port) - ); + this._portEdgeRunners = new(); + foreach (RequestPort port in workflowPorts) + { + if (!this.TryRegisterPort(runContext, port.Id, port)) + { + throw new InvalidOperationException($"Duplicate port ID detected: {port.Id}"); + } + } - this._inputRunner = new ResponseEdgeRunner(runContext, startExecutorId); + this._inputRunner = new ResponseEdgeRunner(runContext, startExecutorId, ""); this._stepTracer = stepTracer; } @@ -71,6 +76,9 @@ public EdgeMap(IRunnerContext runContext, return edgeRunner.ChaseEdgeAsync(message, this._stepTracer); } + public bool TryRegisterPort(IRunnerContext runContext, string executorId, RequestPort port) + => this._portEdgeRunners.TryAdd(port.Id, ResponseEdgeRunner.ForPort(runContext, executorId, port)); + public ValueTask PrepareDeliveryForInputAsync(MessageEnvelope message) { return this._inputRunner.ChaseEdgeAsync(message, this._stepTracer); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/IRunnerContext.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/IRunnerContext.cs index f3fc762336..e84080c6a7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/IRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/IRunnerContext.cs @@ -12,6 +12,6 @@ internal interface IRunnerContext : IExternalRequestSink, ISuperStepJoinContext ValueTask SendMessageAsync(string sourceId, object message, string? targetId = null, CancellationToken cancellationToken = default); ValueTask AdvanceAsync(CancellationToken cancellationToken = default); - IWorkflowContext Bind(string executorId, Dictionary? traceContext = null); + IWorkflowContext BindWorkflowContext(string executorId, Dictionary? traceContext = null); ValueTask EnsureExecutorAsync(string executorId, IStepTracer? tracer, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs index 55e85b8b14..deab3bad52 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Execution/ResponseEdgeRunner.cs @@ -8,17 +8,19 @@ namespace Microsoft.Agents.AI.Workflows.Execution; -internal sealed class ResponseEdgeRunner(IRunnerContext runContext, string sinkId) +internal sealed class ResponseEdgeRunner(IRunnerContext runContext, string executorId, string sinkId) : EdgeRunner(runContext, sinkId) { - public static ResponseEdgeRunner ForPort(IRunnerContext runContext, RequestPort port) + public static ResponseEdgeRunner ForPort(IRunnerContext runContext, string executorId, RequestPort port) { Throw.IfNull(port); // The port is an request port, so we can use the port's ID as the sink ID. - return new ResponseEdgeRunner(runContext, port.Id); + return new ResponseEdgeRunner(runContext, executorId, port.Id); } + public string ExecutorId => executorId; + protected internal override async ValueTask ChaseEdgeAsync(MessageEnvelope envelope, IStepTracer? stepTracer) { Debug.Assert(envelope.IsExternal, "Input edges should only be chased from external input"); @@ -27,7 +29,7 @@ public static ResponseEdgeRunner ForPort(IRunnerContext runContext, RequestPort activity? .SetTag(Tags.EdgeGroupType, nameof(ResponseEdgeRunner)) .SetTag(Tags.MessageSourceId, envelope.SourceId) - .SetTag(Tags.MessageTargetId, this.EdgeData); + .SetTag(Tags.MessageTargetId, $"{this.ExecutorId}[{this.EdgeData}]"); try { @@ -48,5 +50,5 @@ public static ResponseEdgeRunner ForPort(IRunnerContext runContext, RequestPort } } - private async ValueTask FindExecutorAsync(IStepTracer? tracer) => await this.RunContext.EnsureExecutorAsync(this.EdgeData, tracer).ConfigureAwait(false); + private async ValueTask FindExecutorAsync(IStepTracer? tracer) => await this.RunContext.EnsureExecutorAsync(this.ExecutorId, tracer).ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs index 647dbcd852..feb2d8c219 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Executor.cs @@ -16,7 +16,7 @@ namespace Microsoft.Agents.AI.Workflows; /// /// A component that processes messages in a . /// -[DebuggerDisplay("{GetType().Name}{Id}")] +[DebuggerDisplay("{GetType().Name}[{Id}]")] public abstract class Executor : IIdentified { /// @@ -63,6 +63,24 @@ protected Executor(string id, ExecutorOptions? options = null, bool declareCross /// protected abstract RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder); + internal void Configure(IExternalRequestContext externalRequestContext) + { + // TODO: This is an unfortunate pattern (pending the ability to rework the Configure APIs a bit): + // new() + // >>> will throw InvalidOperationException if Configure() is not invoked when using PortHandlers + // .Configure() + // >>> only usable now + // The fix would be to change the API surface of Executor to have Configure return the contract that the workflow + // will use to invoke the executor (currently the MessageRouter). (Ideally we would rename Executor to Node or similar, + // and the actual Executor class will represent that Contract object) + // Not a terrible issue right now because only InProcessExecution exists right now, and the InProccessRunContext centralizes + // executor instantiation in EnsureExecutorAsync. + this.Router = this.CreateRouter(externalRequestContext); + } + + private MessageRouter CreateRouter(IExternalRequestContext? externalRequestContext = null) + => this.ConfigureRoutes(new RouteBuilder(externalRequestContext)).Build(); + /// /// Perform any asynchronous initialization required by the executor. This method is called once per executor instance, /// @@ -99,12 +117,15 @@ internal MessageRouter Router { if (field is null) { - RouteBuilder routeBuilder = this.ConfigureRoutes(new RouteBuilder()); - field = routeBuilder.Build(); + field = this.CreateRouter(); } return field; } + private set + { + field = value; + } } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ExecutorBindingExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ExecutorBindingExtensions.cs index 5a5e197541..edaf959ba7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ExecutorBindingExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ExecutorBindingExtensions.cs @@ -419,9 +419,18 @@ public static ExecutorBinding BindAsExecutor(this FuncThe agent instance. /// Specifies whether the agent should emit streaming events. /// An instance that wraps the provided agent. - public static ExecutorBinding BindAsExecutor(this AIAgent agent, bool emitEvents = false) + public static ExecutorBinding BindAsExecutor(this AIAgent agent, bool emitEvents) => new AIAgentBinding(agent, emitEvents); + /// + /// Configure an as an executor for use in a workflow. + /// + /// The agent instance. + /// Optional configuration options for the AI agent executor. If null, default options are used. + /// An instance that wraps the provided agent. + public static ExecutorBinding BindAsExecutor(this AIAgent agent, AIAgentHostOptions? options = null) + => new AIAgentBinding(agent, options); + /// /// Configure a as an executor for use in a workflow. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/ExternalResponse.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/ExternalResponse.cs index f01668dfa5..a26650cedc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/ExternalResponse.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/ExternalResponse.cs @@ -43,4 +43,13 @@ public record ExternalResponse(RequestPortInfo PortInfo, string RequestId, Porta /// The type to which the data should be cast or converted. /// The data cast to the specified type, or null if the data cannot be cast to the specified type. public object? DataAs(Type targetType) => this.Data.AsType(targetType); + + /// + /// Attempts to retrieve the underlying data as the specified type. + /// + /// The type to which the data should be cast or converted. + /// When this method returns , contains the value of type + /// if the data is available and compatible. + /// true if the data is present and can be cast to ; otherwise, false. + public bool DataIs(Type targetType, [NotNullWhen(true)] out object? value) => this.Data.IsType(targetType, out value); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/GroupChatWorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/GroupChatWorkflowBuilder.cs index 12b0f9c707..9a09f22617 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/GroupChatWorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/GroupChatWorkflowBuilder.cs @@ -50,7 +50,14 @@ public GroupChatWorkflowBuilder AddParticipants(params IEnumerable agen public Workflow Build() { AIAgent[] agents = this._participants.ToArray(); - Dictionary agentMap = agents.ToDictionary(a => a, a => (ExecutorBinding)new AgentRunStreamingExecutor(a, includeInputInOutput: true)); + + AIAgentHostOptions options = new() + { + ReassignOtherAgentsAsUsers = true, + ForwardIncomingMessages = true + }; + + Dictionary agentMap = agents.ToDictionary(a => a, a => a.BindAsExecutor(options)); Func> groupChatHostFactory = (id, runId) => new(new GroupChatHost(id, agents, agentMap, this._managerFactory)); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/IExternalRequestContext.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/IExternalRequestContext.cs new file mode 100644 index 0000000000..13dfcaeb31 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/IExternalRequestContext.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Agents.AI.Workflows.Execution; + +namespace Microsoft.Agents.AI.Workflows; + +internal interface IExternalRequestContext +{ + IExternalRequestSink RegisterPort(RequestPort port); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs index 8c7149b0be..644ab3ec82 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunner.cs @@ -200,7 +200,7 @@ private async ValueTask DeliverMessagesAsync(string receiverId, ConcurrentQueue< await executor.ExecuteAsync( envelope.Message, envelope.MessageType, - this.RunContext.Bind(receiverId, envelope.TraceContext), + this.RunContext.BindWorkflowContext(receiverId, envelope.TraceContext), cancellationToken ).ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs index 1750f779f2..48d70ec280 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/InProc/InProcessRunnerContext.cs @@ -66,6 +66,16 @@ public InProcessRunnerContext( this.OutgoingEvents = outgoingEvents; } + public IExternalRequestSink RegisterPort(string executorId, RequestPort port) + { + if (!this._edgeMap.TryRegisterPort(this, executorId, port)) + { + throw new InvalidOperationException($"A port with ID {port.Id} already exists."); + } + + return this; + } + public async ValueTask EnsureExecutorAsync(string executorId, IStepTracer? tracer, CancellationToken cancellationToken = default) { this.CheckEnded(); @@ -79,7 +89,9 @@ async Task CreateExecutorAsync(string id) } Executor executor = await registration.CreateInstanceAsync(this._runId).ConfigureAwait(false); - await executor.InitializeAsync(this.Bind(executorId), cancellationToken: cancellationToken) + executor.Configure(this.BindExternalRequestContext(executorId)); + + await executor.InitializeAsync(this.BindWorkflowContext(executorId), cancellationToken: cancellationToken) .ConfigureAwait(false); tracer?.TraceActivated(executorId); @@ -211,10 +223,16 @@ await this._edgeMap.PrepareDeliveryForEdgeAsync(edge, envelope) } } - public IWorkflowContext Bind(string executorId, Dictionary? traceContext = null) + public IExternalRequestContext BindExternalRequestContext(string executorId) { this.CheckEnded(); - return new BoundContext(this, executorId, this._outputFilter, traceContext); + return new BoundExternalRequestContext(this, executorId); + } + + public IWorkflowContext BindWorkflowContext(string executorId, Dictionary? traceContext = null) + { + this.CheckEnded(); + return new BoundWorkflowContext(this, executorId, this._outputFilter, traceContext); } public ValueTask PostAsync(ExternalRequest request) @@ -238,7 +256,17 @@ public bool CompleteRequest(string requestId) internal StateManager StateManager { get; } = new(); - private sealed class BoundContext( + private sealed class BoundExternalRequestContext( + InProcessRunnerContext RunnerContext, + string ExecutorId) : IExternalRequestContext + { + public IExternalRequestSink RegisterPort(RequestPort port) + { + return RunnerContext.RegisterPort(ExecutorId, port); + } + } + + private sealed class BoundWorkflowContext( InProcessRunnerContext RunnerContext, string ExecutorId, OutputFilter outputFilter, @@ -303,7 +331,7 @@ internal Task PrepareForCheckpointAsync(CancellationToken cancellationToken = de async Task InvokeCheckpointingAsync(Task executorTask) { Executor executor = await executorTask.ConfigureAwait(false); - await executor.OnCheckpointingAsync(this.Bind(executor.Id), cancellationToken).ConfigureAwait(false); + await executor.OnCheckpointingAsync(this.BindWorkflowContext(executor.Id), cancellationToken).ConfigureAwait(false); } } @@ -316,7 +344,7 @@ internal Task NotifyCheckpointLoadedAsync(CancellationToken cancellationToken = async Task InvokeCheckpointRestoredAsync(Task executorTask) { Executor executor = await executorTask.ConfigureAwait(false); - await executor.OnCheckpointRestoredAsync(this.Bind(executor.Id), cancellationToken).ConfigureAwait(false); + await executor.OnCheckpointRestoredAsync(this.BindWorkflowContext(executor.Id), cancellationToken).ConfigureAwait(false); } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Microsoft.Agents.AI.Workflows.csproj b/dotnet/src/Microsoft.Agents.AI.Workflows/Microsoft.Agents.AI.Workflows.csproj index 7379d9a6ac..fecdb28112 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Microsoft.Agents.AI.Workflows.csproj +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Microsoft.Agents.AI.Workflows.csproj @@ -2,6 +2,7 @@ preview + $(NoWarn);MEAI001 diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/PortBinding.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/PortBinding.cs new file mode 100644 index 0000000000..86ee76cd44 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/PortBinding.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Workflows.Execution; + +namespace Microsoft.Agents.AI.Workflows; + +// TODO: Move this onto IExternalRequestSink? +internal class PortBinding(RequestPort port, IExternalRequestSink sink) +{ + public RequestPort Port => port; + public IExternalRequestSink Sink => sink; + + public ValueTask PostRequestAsync(TRequest request, string? requestId = null, CancellationToken cancellationToken = default) + { + ExternalRequest externalRequest = ExternalRequest.Create(this.Port, request, requestId); + return this.Sink.PostAsync(externalRequest); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs index 99cfdb6992..b4a7d9f62f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/RouteBuilder.cs @@ -22,6 +22,14 @@ System.Threading.Tasks.ValueTask >; +using PortHandlerF = + System.Func< + Microsoft.Agents.AI.Workflows.ExternalResponse, // message + Microsoft.Agents.AI.Workflows.IWorkflowContext, // context + System.Threading.CancellationToken, // cancellation + System.Threading.Tasks.ValueTask + >; + namespace Microsoft.Agents.AI.Workflows; /// @@ -32,10 +40,17 @@ namespace Microsoft.Agents.AI.Workflows; /// public class RouteBuilder { + private readonly IExternalRequestContext? _externalRequestContext; private readonly Dictionary _typedHandlers = []; private readonly Dictionary _outputTypes = []; + private readonly Dictionary _portHandlers = []; private CatchAllF? _catchAll; + internal RouteBuilder(IExternalRequestContext? externalRequestContext) + { + this._externalRequestContext = externalRequestContext; + } + internal RouteBuilder AddHandlerInternal(Type messageType, MessageHandlerF handler, Type? outputType, bool overwrite = false) { Throw.IfNull(messageType); @@ -102,6 +117,44 @@ async ValueTask WrappedHandlerAsync(object message, IWorkflowContext } } + internal RouteBuilder AddPortHandler(string id, Func handler, out PortBinding portBinding, bool overwrite = false) + { + if (this._externalRequestContext == null) + { + throw new InvalidOperationException("An external request context is required to register port handlers."); + } + + RequestPort port = RequestPort.Create(id); + IExternalRequestSink sink = this._externalRequestContext!.RegisterPort(port); + portBinding = new(port, sink); + + if (this._portHandlers.ContainsKey(id) == overwrite) + { + this._portHandlers[id] = InvokeHandlerAsync; + } + else if (overwrite) + { + throw new InvalidOperationException($"A handler for port id {id} is not registered (overwrite = true)."); + } + else + { + throw new InvalidOperationException($"A handler for port id {id} is already registered (overwrite = false)."); + } + + return this; + + async ValueTask InvokeHandlerAsync(ExternalResponse response, IWorkflowContext context, CancellationToken cancellationToken) + { + if (!response.DataIs(out TResponse? typedResponse)) + { + throw new InvalidOperationException($"Received response data is not of expected type {typeof(TResponse).FullName} for port {port.Id}."); + } + + await handler(typedResponse, context, cancellationToken).ConfigureAwait(false); + return response; + } + } + /// /// Registers a handler for messages of the specified input type in the workflow route. /// @@ -279,7 +332,7 @@ public RouteBuilder AddHandler(Func WrappedHandlerAsync(object message, IWorkflowContext context, CancellationToken cancellationToken) { - TResult result = await handler.Invoke((TInput)message, context, cancellationToken).ConfigureAwait(false); + TResult result = await handler((TInput)message, context, cancellationToken).ConfigureAwait(false); return CallResult.ReturnResult(result); } } @@ -514,5 +567,29 @@ ValueTask WrappedHandlerAsync(PortableValue message, IWorkflowContex } } - internal MessageRouter Build() => new(this._typedHandlers, [.. this._outputTypes.Values], this._catchAll); + private void RegisterPortHandlerRouter() + { + Dictionary portHandlers = this._portHandlers; + this.AddHandler(InvokeHandlerAsync); + + ValueTask InvokeHandlerAsync(ExternalResponse response, IWorkflowContext context, CancellationToken cancellationToken) + { + if (portHandlers.TryGetValue(response.PortInfo.PortId, out PortHandlerF? portHandler)) + { + return portHandler(response, context, cancellationToken); + } + + throw new InvalidOperationException($"Unknown port {response.PortInfo}"); + } + } + + internal MessageRouter Build() + { + if (this._portHandlers.Count > 0) + { + this.RegisterPortHandlerRouter(); + } + + return new(this._typedHandlers, [.. this._outputTypes.Values], this._catchAll); + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs index 0a887013a3..78f4039af5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -8,74 +10,229 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; +internal record AIAgentHostState(JsonElement? ThreadState, bool? CurrentTurnEmitEvents); + internal sealed class AIAgentHostExecutor : ChatProtocolExecutor { - private readonly bool _emitEvents; private readonly AIAgent _agent; + private readonly AIAgentHostOptions _options; + private AgentThread? _thread; + private bool? _currentTurnEmitEvents; + + private AIContentExternalHandler? _userInputHandler; + private AIContentExternalHandler? _functionCallHandler; + + private static readonly ChatProtocolExecutorOptions s_defaultChatProtocolOptions = new() + { + AutoSendTurnToken = false, + StringMessageChatRole = ChatRole.User + }; - public AIAgentHostExecutor(AIAgent agent, bool emitEvents = false) : base(id: agent.GetDescriptiveId()) + public AIAgentHostExecutor(AIAgent agent, AIAgentHostOptions options) : base(id: agent.GetDescriptiveId(), + s_defaultChatProtocolOptions, + declareCrossRunShareable: false) // Explicitly false, because we maintain turn state on the instance { this._agent = agent; - this._emitEvents = emitEvents; + this._options = options; + } + + private RouteBuilder ConfigureUserInputRoutes(RouteBuilder routeBuilder) + { + this._userInputHandler = new AIContentExternalHandler( + ref routeBuilder, + portId: $"{this.Id}_UserInput", + intercepted: this._options.InterceptUserInputRequests, + handler: this.HandleUserInputResponseAsync); + + this._functionCallHandler = new AIContentExternalHandler( + ref routeBuilder, + portId: $"{this.Id}_FunctionCall", + intercepted: this._options.InterceptUnterminatedFunctionCalls, + handler: this.HandleFunctionResultAsync); + + return routeBuilder; + } + + protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + routeBuilder = base.ConfigureRoutes(routeBuilder); + return this.ConfigureUserInputRoutes(routeBuilder); + } + + private ValueTask HandleUserInputResponseAsync( + UserInputResponseContent response, + IWorkflowContext context, + CancellationToken cancellationToken) + { + if (!this._userInputHandler!.MarkRequestAsHandled(response.Id)) + { + throw new InvalidOperationException($"No pending UserInputRequest found with id '{response.Id}'."); + } + + // TODO: Are there any issues with taking an implicit turn immediately? + List implicitTurnMessages = [new ChatMessage(ChatRole.User, [response])]; + return this.ContinueTurnAsync(implicitTurnMessages, context, this._currentTurnEmitEvents ?? false, cancellationToken); } + private ValueTask HandleFunctionResultAsync( + FunctionResultContent result, + IWorkflowContext context, + CancellationToken cancellationToken) + { + if (!this._functionCallHandler!.MarkRequestAsHandled(result.CallId)) + { + throw new InvalidOperationException($"No pending UserInputRequest found with id '{result.CallId}'."); + } + + List implicitTurnMessages = [new ChatMessage(ChatRole.Tool, [result])]; + return this.ContinueTurnAsync(implicitTurnMessages, context, this._currentTurnEmitEvents ?? false, cancellationToken); + } + + public bool ShouldEmitStreamingEvents(bool? emitEvents) + => emitEvents ?? this._options.EmitAgentRunUpdateEvents ?? false; + private AgentThread EnsureThread(IWorkflowContext context) => this._thread ??= this._agent.GetNewThread(); - private const string ThreadStateKey = nameof(_thread); + private const string UserInputRequestStateKey = nameof(_userInputHandler); + private const string FunctionCallRequestStateKey = nameof(_functionCallHandler); + private const string AIAgentHostStateKey = nameof(AIAgentHostState); + protected internal override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default) { - Task threadTask = Task.CompletedTask; - if (this._thread is not null) - { - JsonElement threadValue = this._thread.Serialize(); - threadTask = context.QueueStateUpdateAsync(ThreadStateKey, threadValue, cancellationToken: cancellationToken).AsTask(); - } + AIAgentHostState state = new(this._thread?.Serialize(), this._currentTurnEmitEvents); + Task coreStateTask = context.QueueStateUpdateAsync(AIAgentHostStateKey, state, cancellationToken: cancellationToken).AsTask(); + Task userInputRequestsTask = this._userInputHandler?.OnCheckpointingAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task functionCallRequestsTask = this._functionCallHandler?.OnCheckpointingAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; Task baseTask = base.OnCheckpointingAsync(context, cancellationToken).AsTask(); - await Task.WhenAll(threadTask, baseTask).ConfigureAwait(false); + await Task.WhenAll(coreStateTask, userInputRequestsTask, functionCallRequestsTask, baseTask).ConfigureAwait(false); } protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) { - JsonElement? threadValue = await context.ReadStateAsync(ThreadStateKey, cancellationToken: cancellationToken).ConfigureAwait(false); - if (threadValue.HasValue) + Task userInputRestoreTask = this._userInputHandler?.OnCheckpointRestoredAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task functionCallRestoreTask = this._functionCallHandler?.OnCheckpointRestoredAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + + AIAgentHostState? state = await context.ReadStateAsync(AIAgentHostStateKey, cancellationToken: cancellationToken).ConfigureAwait(false); + if (state != null) { - this._thread = this._agent.DeserializeThread(threadValue.Value); + this._thread = state.ThreadState.HasValue ? this._agent.DeserializeThread(state.ThreadState.Value) : null; + this._currentTurnEmitEvents = state.CurrentTurnEmitEvents; } + await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask).ConfigureAwait(false); await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); } - protected override async ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) + private bool HasOutstandingRequests => (this._userInputHandler?.HasPendingRequests == true) + || (this._functionCallHandler?.HasPendingRequests == true); + + // While we save this on the instance, we are not cross-run shareable, but as AgentBinding uses the factory pattern this is not an issue + private async ValueTask ContinueTurnAsync(List messages, IWorkflowContext context, bool emitEvents, CancellationToken cancellationToken) { - if (emitEvents ?? this._emitEvents) + this._currentTurnEmitEvents = emitEvents; + if (this._options.ForwardIncomingMessages) + { + await context.SendMessageAsync(messages, cancellationToken).ConfigureAwait(false); + } + + IEnumerable filteredMessages = this._options.ReassignOtherAgentsAsUsers + ? messages.Select(m => m.ChatAssistantToUserIfNotFromNamed(this._agent.Name ?? this._agent.Id)) + : messages; + + AgentRunResponse response = await this.InvokeAgentAsync(filteredMessages, context, emitEvents, cancellationToken).ConfigureAwait(false); + + await context.SendMessageAsync(response.Messages is List list ? list : response.Messages.ToList(), cancellationToken) + .ConfigureAwait(false); + + // If we have no outstanding requests, we can yield a turn token back to the workflow. + if (!this.HasOutstandingRequests) { + await context.SendMessageAsync(new TurnToken(this._currentTurnEmitEvents), cancellationToken).ConfigureAwait(false); + this._currentTurnEmitEvents = null; // Possibly not actually necessary, but cleaning this up makes it clearer when debugging + } + } + + protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) + => this.ContinueTurnAsync(messages, context, this.ShouldEmitStreamingEvents(emitEvents), cancellationToken); + + private async ValueTask InvokeAgentAsync(IEnumerable messages, IWorkflowContext context, bool emitEvents, CancellationToken cancellationToken = default) + { +#pragma warning disable MEAI001 + Dictionary userInputRequests = new(); + Dictionary functionCalls = new(); + AgentRunResponse response; + + if (emitEvents) + { +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Run the agent in streaming mode only when agent run update events are to be emitted. IAsyncEnumerable agentStream = this._agent.RunStreamingAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken); List updates = []; - await foreach (AgentRunResponseUpdate update in agentStream.ConfigureAwait(false)) { await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false); - - // TODO: FunctionCall request handling, and user info request handling. - // In some sense: We should just let it be handled as a ChatMessage, though we should consider - // providing some mechanisms to help the user complete the request, or route it out of the - // workflow. + ExtractUnservicedRequests(update.Contents); updates.Add(update); } - await context.SendMessageAsync(updates.ToAgentRunResponse().Messages, cancellationToken: cancellationToken).ConfigureAwait(false); + response = updates.ToAgentRunResponse(); } else { // Otherwise, run the agent in non-streaming mode. - AgentRunResponse response = await this._agent.RunAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken).ConfigureAwait(false); - await context.SendMessageAsync(response.Messages, cancellationToken: cancellationToken).ConfigureAwait(false); + response = await this._agent.RunAsync(messages, this.EnsureThread(context), cancellationToken: cancellationToken).ConfigureAwait(false); + ExtractUnservicedRequests(response.Messages.SelectMany(message => message.Contents)); + } + + if (this._options.EmitAgentRunResponseEvents == true) + { + await context.AddEventAsync(new AgentRunResponseEvent(this.Id, response), cancellationToken).ConfigureAwait(false); + } + + if (userInputRequests.Count > 0 || functionCalls.Count > 0) + { + Task userInputTask = this._userInputHandler?.ProcessRequestContentsAsync(userInputRequests, context, cancellationToken) ?? Task.CompletedTask; + Task functionCallTask = this._functionCallHandler?.ProcessRequestContentsAsync(functionCalls, context, cancellationToken) ?? Task.CompletedTask; + + await Task.WhenAll(userInputTask, functionCallTask) + .ConfigureAwait(false); + } + + return response; + + void ExtractUnservicedRequests(IEnumerable contents) + { + foreach (AIContent content in contents) + { + if (content is UserInputRequestContent userInputRequest) + { + // It is an error to simultaneously have multiple outstanding user input requests with the same ID. + userInputRequests.Add(userInputRequest.Id, userInputRequest); + } + else if (content is UserInputResponseContent userInputResponse) + { + // If the set of messages somehow already has a corresponding user input response, remove it. + _ = userInputRequests.Remove(userInputResponse.Id); + } + else if (content is FunctionCallContent functionCall) + { + // For function calls, we emit an event to notify the workflow. + // + // possiblity 1: this will be handled inline by the agent abstraction + // possiblity 2: this will not be handlined inline by the agent abstraction + functionCalls.Add(functionCall.CallId, functionCall); + } + else if (content is FunctionResultContent functionResult) + { + _ = functionCalls.Remove(functionResult.CallId); + } + } } +#pragma warning restore MEAI001 } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs new file mode 100644 index 0000000000..eae1fd90f5 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIContentExternalHandler.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +internal sealed class AIContentExternalHandler + where TRequestContent : AIContent + where TResponseContent : AIContent +{ + private readonly PortBinding? _portBinding; + private ConcurrentDictionary _pendingRequests = new(); + + public AIContentExternalHandler(ref RouteBuilder routeBuilder, string portId, bool intercepted, Func handler) + { + if (intercepted) + { + this._portBinding = null; + routeBuilder = routeBuilder.AddHandler(handler); + } + else + { + routeBuilder = routeBuilder.AddPortHandler(portId, handler, out this._portBinding); + } + } + + public bool HasPendingRequests => !this._pendingRequests.IsEmpty; + + public Task ProcessRequestContentsAsync(Dictionary requests, IWorkflowContext context, CancellationToken cancellationToken = default) + { + IEnumerable requestTasks = from string requestId in requests.Keys + select this.ProcessRequestContentAsync(requestId, requests[requestId], context, cancellationToken) + .AsTask(); + + return Task.WhenAll(requestTasks); + } + + public ValueTask ProcessRequestContentAsync(string id, TRequestContent requestContent, IWorkflowContext context, CancellationToken cancellationToken = default) + { + if (!this._pendingRequests.TryAdd(id, requestContent)) + { + throw new InvalidOperationException($"A pending request with ID '{id}' already exists."); + } + + return this.IsIntercepted + ? context.SendMessageAsync(requestContent, cancellationToken: cancellationToken) + : this._portBinding.PostRequestAsync(requestContent, id, cancellationToken); + } + + public bool MarkRequestAsHandled(string id) + { + return this._pendingRequests.TryRemove(id, out _); + } + + [MemberNotNullWhen(false, nameof(_portBinding))] + private bool IsIntercepted => this._portBinding == null; + + private static string MakeKey(string id) => $"{id}_PendingRequests"; + + public async ValueTask OnCheckpointingAsync(string id, IWorkflowContext context, CancellationToken cancellationToken = default) + { + Dictionary pendingRequestsCopy = new(this._pendingRequests); + await context.QueueStateUpdateAsync(MakeKey(id), pendingRequestsCopy, cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + + public async ValueTask OnCheckpointRestoredAsync(string id, IWorkflowContext context, CancellationToken cancellationToken = default) + { + Dictionary? loadedState = + await context.ReadStateAsync>(MakeKey(id), cancellationToken: cancellationToken) + .ConfigureAwait(false); + + if (loadedState != null) + { + this._pendingRequests = new ConcurrentDictionary(loadedState); + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs deleted file mode 100644 index ae3a932feb..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AgentRunStreamingExecutor.cs +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI.Workflows.Specialized; - -/// -/// Executor that runs the agent and forwards all messages, input and output, to the next executor. -/// -internal sealed class AgentRunStreamingExecutor(AIAgent agent, bool includeInputInOutput) - : ChatProtocolExecutor(agent.GetDescriptiveId(), DefaultOptions, declareCrossRunShareable: true), IResettableExecutor -{ - private static ChatProtocolExecutorOptions DefaultOptions => new() - { - StringMessageChatRole = ChatRole.User - }; - - protected override async ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) - { - List? roleChanged = messages.ChangeAssistantToUserForOtherParticipants(agent.Name ?? agent.Id); - - List updates = []; - await foreach (var update in agent.RunStreamingAsync(messages, cancellationToken: cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - if (emitEvents is true) - { - await context.AddEventAsync(new AgentRunUpdateEvent(this.Id, update), cancellationToken).ConfigureAwait(false); - } - } - - roleChanged.ResetUserToAssistantForChangedRoles(); - - List result = includeInputInOutput ? [.. messages] : []; - result.AddRange(updates.ToAgentRunResponse().Messages); - - await context.SendMessageAsync(result, cancellationToken: cancellationToken).ConfigureAwait(false); - } - - public new ValueTask ResetAsync() => base.ResetAsync(); -} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/CollectChatMessagesExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AggregateTurnMessagesExecutor.cs similarity index 73% rename from dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/CollectChatMessagesExecutor.cs rename to dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AggregateTurnMessagesExecutor.cs index 5a923b9c52..6e7f83d14b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/CollectChatMessagesExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AggregateTurnMessagesExecutor.cs @@ -8,10 +8,10 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; /// -/// Provides an executor that batches received chat messages that it then releases when +/// Provides an executor that aggregates received chat messages that it then releases when /// receiving a . /// -internal sealed class CollectChatMessagesExecutor(string id) : ChatProtocolExecutor(id, declareCrossRunShareable: true), IResettableExecutor +internal sealed class AggregateTurnMessagesExecutor(string id) : ChatProtocolExecutor(id, declareCrossRunShareable: true), IResettableExecutor { /// protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs index 932cf297a3..3dda4a85c6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestInfoExecutor.cs @@ -112,18 +112,11 @@ public async ValueTask HandleAsync(object message, IWorkflowCon public async ValueTask HandleAsync(ExternalResponse message, IWorkflowContext context, CancellationToken cancellationToken = default) { - Throw.IfNull(message); - Throw.IfNull(message.Data); - - if (message.PortInfo.PortId != this.Port.Id) + if (!this.Port.IsResponsePort(message)) { return null; } - object data = message.DataAs(this.Port.Response) ?? - throw new InvalidOperationException( - $"Message type {message.Data.TypeId} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}."); - if (this._allowWrapped && this._wrappedRequests.TryGetValue(message.RequestId, out ExternalRequest? originalRequest)) { await context.SendMessageAsync(originalRequest.RewrapResponse(message), cancellationToken: cancellationToken).ConfigureAwait(false); @@ -133,6 +126,11 @@ public async ValueTask HandleAsync(object message, IWorkflowCon await context.SendMessageAsync(message, cancellationToken: cancellationToken).ConfigureAwait(false); } + if (!message.Data.IsType(this.Port.Response, out object? data)) + { + throw this.Port.CreateExceptionForType(message); + } + await context.SendMessageAsync(data, cancellationToken: cancellationToken).ConfigureAwait(false); return message; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestPortExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestPortExtensions.cs new file mode 100644 index 0000000000..ec128749b7 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/RequestPortExtensions.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +internal static class RequestPortExtensions +{ + /// + /// Attempts to process the incoming as a response to a request sent + /// through the specified . If the response is to a different port, returns + /// . If the port matches, but the response data cannot be interpreted as the + /// expected response type, throws an . Otherwise, returns + /// . + /// + /// The request port through which the original request was sent. + /// The candidate response to be processed + /// if the response is for the specified port and the data could be + /// interpreted as the expected response type; otherwise, . + /// Thrown if the response is for the specified port, + /// but the data could not be interpreted as the expected response type. + public static bool ShouldProcessResponse(this RequestPort port, ExternalResponse response) + { + Throw.IfNull(response); + Throw.IfNull(response.Data); + + if (!port.IsResponsePort(response)) + { + return false; + } + + if (!response.Data.IsType(port.Response)) + { + throw port.CreateExceptionForType(response); + } + + return true; + } + + internal static bool IsResponsePort(this RequestPort port, ExternalResponse response) + => Throw.IfNull(response).PortInfo.PortId == port.Id; + + internal static InvalidOperationException CreateExceptionForType(this RequestPort port, ExternalResponse response) + => new($"Message type {response.Data.TypeId} is not assignable to the response type {port.Response.Name}" + + $" of input port {port.Id}."); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs index 456838b9eb..7b48a306ec 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Workflow.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Checkpointing; +using Microsoft.Agents.AI.Workflows.Execution; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows; @@ -179,6 +180,16 @@ internal async ValueTask ReleaseOwnershipAsync(object ownerToken) await this.TryResetExecutorRegistrationsAsync().ConfigureAwait(false); } + private sealed class NoOpExternalRequestContext : IExternalRequestContext, IExternalRequestSink + { + public ValueTask PostAsync(ExternalRequest request) => default; + + IExternalRequestSink IExternalRequestContext.RegisterPort(RequestPort port) + { + return this; + } + } + /// /// Retrieves a defining how to interact with this workflow. /// @@ -190,6 +201,8 @@ public async ValueTask DescribeProtocolAsync(CancellationTok ExecutorBinding startExecutorRegistration = this.ExecutorBindings[this.StartExecutorId]; Executor startExecutor = await startExecutorRegistration.CreateInstanceAsync(string.Empty) .ConfigureAwait(false); + startExecutor.Configure(new NoOpExternalRequestContext()); + return startExecutor.DescribeProtocol(); } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowsJsonUtilities.cs index d8241f4681..7ffc0c8d7c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowsJsonUtilities.cs @@ -93,7 +93,7 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(TurnToken))] // Built-in Executor State Types - [JsonSerializable(typeof(AIAgentHostExecutor))] + [JsonSerializable(typeof(AIAgentHostState))] // Event Types //[JsonSerializable(typeof(WorkflowEvent))] diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AIAgentHostExecutorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AIAgentHostExecutorTests.cs new file mode 100644 index 0000000000..4a57f6cb57 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AIAgentHostExecutorTests.cs @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Specialized; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public class AIAgentHostExecutorTests +{ + private const string TestAgentId = nameof(TestAgentId); + private const string TestAgentName = nameof(TestAgentName); + + private static readonly string[] s_messageStrings = [ + "", + "Hello world!", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "Quisque dignissim ante odio, at facilisis orci porta a. Duis mi augue, fringilla eu egestas a, pellentesque sed lacus." + ]; + + private static List TestMessages => TestReplayAgent.ToChatMessages(s_messageStrings); + + [Theory] + [InlineData(null, null)] + [InlineData(null, true)] + [InlineData(null, false)] + [InlineData(true, null)] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, null)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task Test_AgentHostExecutor_EmitsStreamingUpdatesIFFConfiguredAsync(bool? executorSetting, bool? turnSetting) + { + // Arrange + TestRunContext testContext = new(); + TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); + AIAgentHostExecutor executor = new(agent, new() { EmitAgentRunUpdateEvents = executorSetting }); + testContext.ConfigureExecutor(executor); + + // Act + await executor.TakeTurnAsync(new(turnSetting), testContext.BindWorkflowContext(executor.Id)); + + // Assert + // The rules are: TurnToken overrides Agent, if set. Default to false, if both unset. + bool expectingEvents = turnSetting ?? executorSetting ?? false; + + AgentRunUpdateEvent[] updates = testContext.Events.OfType().ToArray(); + if (expectingEvents) + { + // The way TestReplayAgent is set up, it will emit one update per non-empty AIContent + List expectedUpdateContents = TestMessages.SelectMany(message => message.Contents).ToList(); + + updates.Should().HaveCount(expectedUpdateContents.Count); + for (int i = 0; i < updates.Length; i++) + { + AgentRunUpdateEvent updateEvent = updates[i]; + AIContent expectedUpdateContent = expectedUpdateContents[i]; + + updateEvent.ExecutorId.Should().Be(agent.GetDescriptiveId()); + + AgentRunResponseUpdate update = updateEvent.Update; + update.AuthorName.Should().Be(TestAgentName); + update.AgentId.Should().Be(TestAgentId); + update.Contents.Should().HaveCount(1); + update.Contents[0].Should().BeEquivalentTo(expectedUpdateContent); + } + } + else + { + updates.Should().BeEmpty(); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Test_AgentHostExecutor_EmitsResponseIFFConfiguredAsync(bool executorSetting) + { + // Arrange + TestRunContext testContext = new(); + TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); + AIAgentHostExecutor executor = new(agent, new() { EmitAgentRunResponseEvents = executorSetting }); + testContext.ConfigureExecutor(executor); + + // Act + await executor.TakeTurnAsync(new(), testContext.BindWorkflowContext(executor.Id)); + + // Assert + AgentRunResponseEvent[] updates = testContext.Events.OfType().ToArray(); + if (executorSetting) + { + updates.Should().HaveCount(1); + + AgentRunResponseEvent responseEvent = updates[0]; + responseEvent.ExecutorId.Should().Be(agent.GetDescriptiveId()); + + AgentRunResponse response = responseEvent.Response; + response.AgentId.Should().Be(TestAgentId); + response.Messages.Should().HaveCount(TestMessages.Count - 1); + + for (int i = 0; i < response.Messages.Count; i++) + { + ChatMessage responseMessage = response.Messages[i]; + ChatMessage expectedMessage = TestMessages[i + 1]; // Skip the first empty message + + responseMessage.AuthorName.Should().Be(TestAgentName); + responseMessage.Text.Should().Be(expectedMessage.Text); + } + } + else + { + updates.Should().BeEmpty(); + } + } + + private static ChatMessage UserMessage => new(ChatRole.User, "Hello from User!") { AuthorName = "User" }; + private static ChatMessage AssistantMessage => new(ChatRole.Assistant, "Hello from Assistant!") { AuthorName = "User" }; + private static ChatMessage TestAgentMessage => new(ChatRole.Assistant, $"Hello from {TestAgentName}!") { AuthorName = TestAgentName }; + + [Theory] + [InlineData(true, true, false, false)] + [InlineData(true, true, false, true)] + [InlineData(true, true, true, false)] + [InlineData(true, true, true, true)] + [InlineData(true, false, false, false)] + [InlineData(true, false, false, true)] + [InlineData(true, false, true, false)] + [InlineData(true, false, true, true)] + [InlineData(false, true, false, false)] + [InlineData(false, true, false, true)] + [InlineData(false, true, true, false)] + [InlineData(false, true, true, true)] + [InlineData(false, false, false, false)] + [InlineData(false, false, false, true)] + [InlineData(false, false, true, false)] + [InlineData(false, false, true, true)] + public async Task Test_AgentHostExecutor_ReassignsRolesIFFConfiguredAsync(bool executorSetting, bool includeUser, bool includeSelfMessages, bool includeOtherMessages) + { + // Arrange + TestRunContext testContext = new(); + RoleCheckAgent agent = new(false, TestAgentId, TestAgentName); + AIAgentHostExecutor executor = new(agent, new() { ReassignOtherAgentsAsUsers = executorSetting }); + testContext.ConfigureExecutor(executor); + + List messages = []; + + if (includeUser) + { + messages.Add(UserMessage); + } + + if (includeSelfMessages) + { + messages.Add(TestAgentMessage); + } + + if (includeOtherMessages) + { + messages.Add(AssistantMessage); + } + + // Act + await executor.Router.RouteMessageAsync(messages, testContext.BindWorkflowContext(executor.Id)); + + Func act = async () => await executor.TakeTurnAsync(new(), testContext.BindWorkflowContext(executor.Id)); + + // Assert + bool shouldThrow = includeOtherMessages && !executorSetting; + + if (shouldThrow) + { + await act.Should().ThrowAsync(); + } + else + { + await act.Should().NotThrowAsync(); + } + } + + [Theory] + [InlineData(true, TestAgentRequestType.FunctionCall)] + [InlineData(false, TestAgentRequestType.FunctionCall)] + //[InlineData(true, TestAgentRequestType.UserInputRequest)] Broken until we support polymorphic routing + [InlineData(false, TestAgentRequestType.UserInputRequest)] + public async Task Test_AgentHostExecutor_InterceptsRequestsIFFConfiguredAsync(bool intercept, TestAgentRequestType requestType) + { + const int UnpairedRequestCount = 2; + const int PairedRequestCount = 3; + + // Arrange + TestRunContext testContext = new(); + TestRequestAgent agent = new(requestType, UnpairedRequestCount, PairedRequestCount, TestAgentId, TestAgentName); + AIAgentHostOptions agentHostOptions = requestType switch + { + TestAgentRequestType.FunctionCall => + new() + { + EmitAgentRunResponseEvents = true, + InterceptUnterminatedFunctionCalls = intercept + }, + TestAgentRequestType.UserInputRequest => + new() + { + EmitAgentRunResponseEvents = true, + InterceptUserInputRequests = intercept + }, + _ => throw new NotSupportedException() + }; + + AIAgentHostExecutor executor = new(agent, agentHostOptions); + testContext.ConfigureExecutor(executor); + + // Act + await executor.TakeTurnAsync(new(), testContext.BindWorkflowContext(executor.Id)); + + // Assert + List responses; + if (intercept) + { + // We expect to have a sent message containing the requests as an ExternalRequest + switch (requestType) + { + case TestAgentRequestType.FunctionCall: + responses = ExtractAndValidateRequestContents(); + break; + case TestAgentRequestType.UserInputRequest: + responses = ExtractAndValidateRequestContents(); + break; + default: + throw new NotSupportedException(); + } + + List ExtractAndValidateRequestContents() where TRequest : AIContent + { + IEnumerable requests = testContext.QueuedMessages.Should().ContainKey(executor.Id) + .WhoseValue + .Select(envelope => envelope.Message as TRequest) + .Where(item => item is not null) + .Select(item => item!); + + return agent.ValidateUnpairedRequests(requests).ToList(); + } + } + else + { + responses = agent.ValidateUnpairedRequests([.. testContext.ExternalRequests]).ToList(); + } + + // Act 2 + foreach (object response in responses.Take(UnpairedRequestCount - 1)) + { + await executor.Router.RouteMessageAsync(response, testContext.BindWorkflowContext(executor.Id)); + } + + // Assert 2 + // Since we are not finished, we expect the agent to not have produced a final response (="Remaining: 1") + AgentRunResponseEvent lastResponseEvent = testContext.Events.OfType().Should().NotBeEmpty() + .And.Subject.Last(); + + lastResponseEvent.Response.Text.Should().Be("Remaining: 1"); + + // Act 3 + object finalResponse = responses.Last(); + await executor.Router.RouteMessageAsync(finalResponse, testContext.BindWorkflowContext(executor.Id)); + + // Assert 3 + // Now that we are finished, we expect the agent to have produced a final response + lastResponseEvent = testContext.Events.OfType().Should().NotBeEmpty() + .And.Subject.Last(); + + lastResponseEvent.Response.Text.Should().Be("Done"); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs new file mode 100644 index 0000000000..cbeb13c86b --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicPortsExecutor.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +internal sealed class DynamicPortsExecutor(string id, params IEnumerable ports) : Executor(id) +{ + public Dictionary PortBindings { get; } = new(); + + public ConcurrentDictionary> ReceivedResponses { get; } = new(); + + protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + foreach (string portId in ports) + { + routeBuilder = routeBuilder + .AddPortHandler(portId, + (response, context, cancellationToken) => + { + this.ReceivedResponses.GetOrAdd(portId, _ => new()).Enqueue(response); + return default; + }, out PortBinding? binding); + + this.PortBindings[portId] = binding; + } + + return routeBuilder; + } + + public ValueTask PostRequestAsync(string portId, TRequest request, TestRunContext testContext, string? requestId = null) + { + PortBinding binding = this.PortBindings[portId]; + return binding.Sink.PostAsync(ExternalRequest.Create(binding.Port, request, requestId)); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs new file mode 100644 index 0000000000..568bab8120 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/DynamicRequestPortTests.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; +using Microsoft.Agents.AI.Workflows.Execution; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public class DynamicRequestPortTests +{ + private sealed class RequestPortTestContext + { + private const string PortId = "Port1"; + private const string ExecutorId = "Executor1"; + + public RequestPortTestContext() + { + this.Executor = new(ExecutorId, PortId); + this.Executor.Configure(this.ExternalRequestContext); + } + + public TestRunContext RunContext { get; } = new(); + public ExternalRequestContext ExternalRequestContext { get; } = new(); + + public DynamicPortsExecutor Executor { get; } + + public PortBinding PortBinding => this.Executor.PortBindings[PortId]; + + public ExternalRequest Request => this.ExternalRequestContext.ExternalRequests[0]; + + public static async ValueTask CreateAsync(string requestData = "Request", bool validate = true) + { + RequestPortTestContext result = new(); + + await result.Executor.PostRequestAsync(PortId, requestData, result.RunContext); + + if (validate) + { + result.ExternalRequestContext + .ExternalRequests.Should().HaveCount(1) + .And.AllSatisfy(request => request.PortInfo.Should().Be(result.PortBinding.Port.ToPortInfo())); + } + + return result; + } + + public ValueTask InvokeExecutorWithResponseAsync(ExternalResponse response) + => this.Executor.ExecuteAsync(response, new(typeof(ExternalResponse)), this.RunContext.BindWorkflowContext(this.Executor.Id)); + } + + private sealed class ExternalRequestContext : IExternalRequestContext, IExternalRequestSink + { + public List ExternalRequests { get; } = new(); + + public ValueTask PostAsync(ExternalRequest request) + { + this.ExternalRequests.Add(request); + return default; + } + + public IExternalRequestSink RegisterPort(RequestPort port) + { + return this; + } + } + + [Fact] + public async Task Test_DynamicRequestPort_DeliversExpectedResponseAsync() + { + RequestPortTestContext context = await RequestPortTestContext.CreateAsync(); + + ExternalRequest request = context.Request; + await context.InvokeExecutorWithResponseAsync(request.CreateResponse(13)); + + string portId = request.PortInfo.PortId; + context.Executor.ReceivedResponses.Should().HaveCount(1) + .And.ContainKey(portId); + context.Executor.ReceivedResponses[portId].Should().HaveCount(1); + context.Executor.ReceivedResponses[portId].First().Should().Be(13); + } + + [Fact] + public async Task Test_DynamicRequestPort_ThrowsOnWrongPortAsync() + { + RequestPortTestContext context = await RequestPortTestContext.CreateAsync(); + + ExternalRequest request = context.Request; + ExternalRequest fakeRequest = new(RequestPort.Create("port2").ToPortInfo(), request.RequestId, request.Data); + + Func act = async () => await context.InvokeExecutorWithResponseAsync(fakeRequest.CreateResponse(13)); + (await act.Should().ThrowAsync()) + .WithInnerException(); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs index 5ea4715680..34d477600e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeMapSmokeTests.cs @@ -1,24 +1,95 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; using Microsoft.Agents.AI.Workflows.Execution; +using Microsoft.Agents.AI.Workflows.Specialized; namespace Microsoft.Agents.AI.Workflows.UnitTests; public class EdgeMapSmokeTests { [Fact] - public async Task Test_EdgeMap_MaintainsFanInEdgeStateAsync() + public async Task Test_EdgeMap_RoutesStaticPortAsync() + { + TestRunContext runContext = new(); + + RequestPort staticPort = RequestPort.Create("port1"); + RequestInfoExecutor executor = new(staticPort); + EdgeMap edgeMap = new(runContext, [], [staticPort], executor.Id, null); + + runContext.ConfigureExecutor(executor, edgeMap); + + ExternalResponse responseMessage = new(staticPort.ToPortInfo(), "Request1", new(12)); + + DeliveryMapping? mapping = await edgeMap.PrepareDeliveryForResponseAsync(responseMessage); + mapping.Should().NotBeNull(); + + List deliveries = mapping.Deliveries.ToList(); + deliveries.Should().HaveCount(1).And.AllSatisfy(delivery => delivery.TargetId.Should().Be(executor.Id)); + deliveries[0].Envelope.Message.Should().Be(responseMessage); + } + + [Fact] + public async Task Test_EdgeMap_RoutesDynamicPortAsync() + { + TestRunContext runContext = new(); + + DynamicPortsExecutor executor = new("executor1", "port1", "port2"); + EdgeMap edgeMap = new(runContext, [], [], executor.Id, null); + + runContext.ConfigureExecutor(executor, edgeMap); + + await RunPortTestAsync("port1"); + await RunPortTestAsync("port2"); + + async ValueTask RunPortTestAsync(string portId) + { + PortBinding binding = executor.PortBindings[portId]; + ExternalResponse responseMessage = new(binding.Port.ToPortInfo(), $"RequestFor[{portId}]", new(10)); + + DeliveryMapping? mapping = await edgeMap.PrepareDeliveryForResponseAsync(responseMessage); + mapping.Should().NotBeNull(); + + List deliveries = mapping.Deliveries.ToList(); + deliveries.Should().HaveCount(1).And.AllSatisfy(delivery => delivery.TargetId.Should().Be(executor.Id)); + deliveries[0].Envelope.Message.Should().Be(responseMessage); + } + } + + [Fact] + public async Task Test_EdgeMap_DoesNotRouteUnregisteredPortAsync() { TestRunContext runContext = new(); - runContext.Executors["executor1"] = new ForwardMessageExecutor("executor1"); - runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); - runContext.Executors["executor3"] = new ForwardMessageExecutor("executor3"); + RequestPort staticPort = RequestPort.Create("port1"); + RequestInfoExecutor staticExecutor = new(staticPort); + DynamicPortsExecutor executor = new("executor1", "port2", "port3"); + EdgeMap edgeMap = new(runContext, [], [staticPort], executor.Id, null); + + runContext.ConfigureExecutors([staticExecutor, executor], edgeMap); + + await RunPortTestAsync("port4"); + + async ValueTask RunPortTestAsync(string portId) + { + RequestPort fakePort = RequestPort.Create(portId); + + ExternalResponse responseMessage = new(fakePort.ToPortInfo(), $"RequestFor[{portId}]", new(10)); + Func> mappingTask = async () => await edgeMap.PrepareDeliveryForResponseAsync(responseMessage); + await mappingTask.Should().ThrowAsync(); + } + } + + [Fact] + public async Task Test_EdgeMap_MaintainsFanInEdgeStateAsync() + { + TestRunContext runContext = new(); Dictionary> workflowEdges = []; FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0), null); @@ -26,9 +97,15 @@ public async Task Test_EdgeMap_MaintainsFanInEdgeStateAsync() workflowEdges["executor1"] = [fanInEdge]; workflowEdges["executor2"] = [fanInEdge]; - EdgeMap edgeMap = new(runContext, workflowEdges, [], "executor1", null); + runContext.ConfigureExecutors( + [ + new ForwardMessageExecutor("executor1"), + new ForwardMessageExecutor("executor2"), + new ForwardMessageExecutor("executor3") + ], edgeMap); + DeliveryMapping? mapping = await edgeMap.PrepareDeliveryForEdgeAsync(fanInEdge, new("part1", "executor1")); mapping.Should().BeNull(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs index 99cd46dd4b..de23780789 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/EdgeRunnerTests.cs @@ -28,9 +28,11 @@ private static async Task CreateAndRunDirectedEdgeTestAsync(bool? conditionMatch : null; TestRunContext runContext = new(); - - runContext.Executors["executor1"] = new ForwardMessageExecutor("executor1"); - runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); + runContext.ConfigureExecutors( + [ + new ForwardMessageExecutor("executor1"), + new ForwardMessageExecutor("executor2") + ]); DirectEdgeData edgeData = new("executor1", "executor2", new EdgeId(0), condition); DirectEdgeRunner runner = new(runContext, edgeData); @@ -78,9 +80,11 @@ private static async Task CreateAndRunFanOutEdgeTestAsync(bool? assignerSelectsE { TestRunContext runContext = new(); - runContext.Executors["executor1"] = new ForwardMessageExecutor("executor1"); - runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); - runContext.Executors["executor3"] = new ForwardMessageExecutor("executor3"); + runContext.ConfigureExecutors([ + new ForwardMessageExecutor("executor1"), + new ForwardMessageExecutor("executor2"), + new ForwardMessageExecutor("executor3") + ]); Func>? assigner = assignerSelectsEmpty.HasValue @@ -150,10 +154,11 @@ public async Task Test_FanOutEdgeRunnerAsync() public async Task Test_FanInEdgeRunnerAsync() { TestRunContext runContext = new(); - - runContext.Executors["executor1"] = new ForwardMessageExecutor("executor1"); - runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); - runContext.Executors["executor3"] = new ForwardMessageExecutor("executor3"); + runContext.ConfigureExecutors([ + new ForwardMessageExecutor("executor1"), + new ForwardMessageExecutor("executor2"), + new ForwardMessageExecutor("executor3") + ]); FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0), null); FanInEdgeRunner runner = new(runContext, edgeData); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs index 5eb8696221..2c89c2bb6b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs @@ -85,7 +85,7 @@ async ValueTask MessageHandlerAsync(int message, IWorkflowContext workflowContex [Fact] public async Task Test_SpecializedExecutor_InfosAsync() { - await RunExecutorBindingInfoMatchTestAsync(new AIAgentHostExecutor(new TestAgent())); + await RunExecutorBindingInfoMatchTestAsync(new AIAgentHostExecutor(new TestAgent(), new())); await RunExecutorBindingInfoMatchTestAsync(new RequestInfoExecutor(TestRequestPort)); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs new file mode 100644 index 0000000000..e678cc0feb --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +internal sealed class RoleCheckAgent(bool allowOtherAssistantRoles, string? id = null, string? name = null) : AIAgent +{ + protected override string? IdCore => id; + + public override string? Name => name; + + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + => new RoleCheckAgentThread(); + + public override AgentThread GetNewThread() => new RoleCheckAgentThread(); + + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => this.RunStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + + protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (ChatMessage message in messages) + { + if (!allowOtherAssistantRoles && message.Role == ChatRole.Assistant && !(message.AuthorName == null || message.AuthorName == this.Name)) + { + throw new InvalidOperationException($"Message from other assistant role detected: AuthorName={message.AuthorName}"); + } + } + + yield return new AgentRunResponseUpdate(ChatRole.Assistant, "Ok") + { + AgentId = this.Id, + AuthorName = this.Name, + MessageId = Guid.NewGuid().ToString("N"), + ResponseId = Guid.NewGuid().ToString("N") + }; + } + + private sealed class RoleCheckAgentThread : InMemoryAgentThread; +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs index ed9af701c6..4e5db48c3b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs @@ -2,9 +2,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -17,102 +14,6 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; public class SpecializedExecutorSmokeTests { - public class TestAIAgent(List? messages = null, string? id = null, string? name = null) : AIAgent - { - protected override string? IdCore => id; - public override string? Name => name; - - public static List ToChatMessages(params string[] messages) - { - List result = messages.Select(ToMessage).ToList(); - - static ChatMessage ToMessage(string text) - { - if (string.IsNullOrEmpty(text)) - { - return new ChatMessage(ChatRole.Assistant, "") { MessageId = "" }; - } - - string[] splits = text.Split(' '); - for (int i = 0; i < splits.Length - 1; i++) - { - splits[i] += ' '; - } - - List contents = splits.Select(text => new TextContent(text) { RawRepresentation = text }).ToList(); - return new(ChatRole.Assistant, contents) - { - MessageId = Guid.NewGuid().ToString("N"), - RawRepresentation = text, - CreatedAt = DateTime.UtcNow, - }; - } - - return result; - } - - public override AgentThread GetNewThread() - => new TestAgentThread(); - - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => new TestAgentThread(); - - public static TestAIAgent FromStrings(params string[] messages) => - new(ToChatMessages(messages)); - - public List Messages { get; } = Validate(messages) ?? []; - - protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => - Task.FromResult(new AgentRunResponse(this.Messages) - { - AgentId = this.Id, - ResponseId = Guid.NewGuid().ToString("N") - }); - - protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - string responseId = Guid.NewGuid().ToString("N"); - foreach (ChatMessage message in this.Messages) - { - foreach (AIContent content in message.Contents) - { - yield return new AgentRunResponseUpdate() - { - AgentId = this.Id, - MessageId = message.MessageId, - ResponseId = responseId, - Contents = [content], - Role = message.Role, - }; - } - } - } - - private static List? Validate(List? candidateMessages) - { - string? currentMessageId = null; - - if (candidateMessages is not null) - { - foreach (ChatMessage message in candidateMessages) - { - if (currentMessageId is null) - { - currentMessageId = message.MessageId; - } - else if (currentMessageId == message.MessageId) - { - throw new ArgumentException("Duplicate consecutive message ids"); - } - } - } - - return candidateMessages; - } - } - - public sealed class TestAgentThread() : InMemoryAgentThread(); - internal sealed class TestWorkflowContext(string executorId, bool concurrentRunsEnabled = false) : IWorkflowContext { private readonly StateManager _stateManager = new(); @@ -177,10 +78,10 @@ public async Task Test_AIAgentStreamingMessage_AggregationAsync() "Quisque dignissim ante odio, at facilisis orci porta a. Duis mi augue, fringilla eu egestas a, pellentesque sed lacus." ]; - List expected = TestAIAgent.ToChatMessages(MessageStrings); + List expected = TestReplayAgent.ToChatMessages(MessageStrings); - TestAIAgent agent = new(expected); - AIAgentHostExecutor host = new(agent); + TestReplayAgent agent = new(expected); + AIAgentHostExecutor host = new(agent, new()); TestWorkflowContext collectingContext = new(host.Id); @@ -203,8 +104,8 @@ public async Task Test_AIAgent_ExecutorId_Use_Agent_NameAsync() { const string AgentAName = "TestAgentAName"; const string AgentBName = "TestAgentBName"; - TestAIAgent agentA = new(name: AgentAName); - TestAIAgent agentB = new(name: AgentBName); + TestReplayAgent agentA = new(name: AgentAName); + TestReplayAgent agentB = new(name: AgentBName); var workflow = new WorkflowBuilder(agentA).AddEdge(agentA, agentB).Build(); var definition = workflow.ToWorkflowInfo(); @@ -225,8 +126,8 @@ public async Task Test_AIAgent_ExecutorId_Use_Agent_NameAsync() [Fact] public async Task Test_AIAgent_ExecutorId_Use_Agent_ID_When_Name_Not_ProvidedAsync() { - TestAIAgent agentA = new(); - TestAIAgent agentB = new(); + TestReplayAgent agentA = new(); + TestReplayAgent agentB = new(); var workflow = new WorkflowBuilder(agentA).AddEdge(agentA, agentB).Build(); var definition = workflow.ToWorkflowInfo(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs new file mode 100644 index 0000000000..7cebfebffa --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public class TestReplayAgent(List? messages = null, string? id = null, string? name = null) : AIAgent +{ + protected override string? IdCore => id; + public override string? Name => name; + + public static List ToChatMessages(params string[] messages) + { + List result = messages.Select(ToMessage).ToList(); + + static ChatMessage ToMessage(string text) + { + if (string.IsNullOrEmpty(text)) + { + return new ChatMessage(ChatRole.Assistant, "") { MessageId = "" }; + } + + string[] splits = text.Split(' '); + for (int i = 0; i < splits.Length - 1; i++) + { + splits[i] += ' '; + } + + List contents = splits.Select(text => new TextContent(text) { RawRepresentation = text }).ToList(); + return new(ChatRole.Assistant, contents) + { + MessageId = Guid.NewGuid().ToString("N"), + RawRepresentation = text, + CreatedAt = DateTime.UtcNow, + }; + } + + return result; + } + + public override AgentThread GetNewThread() + => new ReplayAgentThread(); + + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + => new ReplayAgentThread(); + + public static TestReplayAgent FromStrings(params string[] messages) => + new(ToChatMessages(messages)); + + public List Messages { get; } = Validate(messages) ?? []; + + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => this.RunStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + + protected override async IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string responseId = Guid.NewGuid().ToString("N"); + foreach (ChatMessage message in this.Messages) + { + foreach (AIContent content in message.Contents) + { + yield return new AgentRunResponseUpdate() + { + AgentId = this.Id, + AuthorName = this.Name, + MessageId = message.MessageId, + ResponseId = responseId, + Contents = [content], + Role = message.Role, + }; + } + } + } + + private static List? Validate(List? candidateMessages) + { + string? currentMessageId = null; + + if (candidateMessages is not null) + { + foreach (ChatMessage message in candidateMessages) + { + if (currentMessageId is null) + { + currentMessageId = message.MessageId; + } + else if (currentMessageId == message.MessageId) + { + throw new ArgumentException("Duplicate consecutive message ids"); + } + } + } + + return candidateMessages; + } + + private sealed class ReplayAgentThread() : InMemoryAgentThread(); +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs new file mode 100644 index 0000000000..2cc21d987d --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +internal sealed record TestRequestAgentThreadState(JsonElement ThreadState, Dictionary UnservicedRequests, HashSet ServicedRequests, HashSet PairedRequests); + +public enum TestAgentRequestType +{ + FunctionCall, + UserInputRequest +} + +internal sealed class TestRequestAgent(TestAgentRequestType requestType, int unpairedRequestCount, int pairedRequestCount, string? id, string? name) : AIAgent +{ + public Random RNG { get; set; } = new Random(HashCode.Combine(requestType, nameof(TestRequestAgent))); + + public AgentThread? LastThread { get; set; } + + protected override string? IdCore => id; + public override string? Name => name; + + public override AgentThread GetNewThread() + => requestType switch + { + TestAgentRequestType.FunctionCall => new TestRequestAgentThread(), + TestAgentRequestType.UserInputRequest => new TestRequestAgentThread(), + _ => throw new NotSupportedException(), + }; + + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + => requestType switch + { + TestAgentRequestType.FunctionCall => new TestRequestAgentThread(), + TestAgentRequestType.UserInputRequest => new TestRequestAgentThread(), + _ => throw new NotSupportedException(), + }; + + protected override Task RunCoreAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + => this.RunStreamingAsync(messages, thread, options, cancellationToken).ToAgentRunResponseAsync(cancellationToken); + + private static int[] SampleIndicies(Random rng, int n, int c) + { + int[] result = Enumerable.Range(0, c).ToArray(); + + for (int i = c; i < n; i++) + { + int radix = rng.Next(i); + if (radix < c) + { + result[radix] = i; + } + } + + return result; + } + + private async IAsyncEnumerable RunStreamingAsync( + IRequestResponseStrategy strategy, + IEnumerable messages, + AgentThread? thread = null, + AgentRunOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TRequest : AIContent + where TResponse : AIContent + { + this.LastThread = thread ??= this.GetNewThread(); + TestRequestAgentThread traThread = ConvertThread(thread); + + if (traThread.HasSentRequests) + { + foreach (TResponse response in messages.SelectMany(message => message.Contents).OfType()) + { + strategy.ProcessResponse(response, traThread); + } + + if (traThread.UnservicedRequests.Count == 0) + { + yield return new(ChatRole.Assistant, "Done"); + } + else + { + yield return new(ChatRole.Assistant, $"Remaining: {traThread.UnservicedRequests.Count}"); + } + } + else + { + int totalRequestCount = unpairedRequestCount + pairedRequestCount; + yield return new(ChatRole.Assistant, $"Creating {totalRequestCount} requests, {pairedRequestCount} paired."); + + HashSet servicedIndicies = [.. SampleIndicies(this.RNG, totalRequestCount, pairedRequestCount)]; + + (string, TRequest)[] requests = strategy.CreateRequests(unpairedRequestCount + pairedRequestCount).ToArray(); + List pairedResponses = new(capacity: pairedRequestCount); + + for (int i = 0; i < requests.Length; i++) + { + (string id, TRequest request) = requests[i]; + if (servicedIndicies.Contains(i)) + { + traThread.PairedRequests.Add(id); + pairedResponses.Add(strategy.CreatePairedResponse(request)); + } + else + { + traThread.UnservicedRequests.Add(id, request); + } + + yield return new(ChatRole.Assistant, [request]); + } + + yield return new(ChatRole.Assistant, pairedResponses); + + traThread.HasSentRequests = true; + } + } + + private static TestRequestAgentThread ConvertThread(AgentThread thread) + where TRequest : AIContent + where TResponse : AIContent + { + if (thread is not TestRequestAgentThread traThread) + { + throw new ArgumentException($"Bad AgentThread type: Expected {typeof(TestRequestAgentThread)}, got {thread.GetType()}.", nameof(thread)); + } + + return traThread; + } + + private sealed class FunctionCallStrategy : IRequestResponseStrategy + { + public FunctionResultContent CreatePairedResponse(FunctionCallContent request) + { + return new FunctionResultContent(request.CallId, request); + } + + public IEnumerable<(string, FunctionCallContent)> CreateRequests(int count) + { + for (int i = 0; i < count; i++) + { + string callId = Guid.NewGuid().ToString("N"); + FunctionCallContent request = new(callId, "TestFunction"); + yield return (callId, request); + } + } + + public void ProcessResponse(FunctionResultContent response, TestRequestAgentThread thread) + { + if (thread.UnservicedRequests.TryGetValue(response.CallId, out FunctionCallContent? request)) + { + response.Result.As().Should().Be(request); + thread.ServicedRequests.Add(response.CallId); + thread.UnservicedRequests.Remove(response.CallId); + } + else if (thread.ServicedRequests.Contains(response.CallId)) + { + throw new InvalidOperationException($"Seeing duplicate response with id {response.CallId}"); + } + else if (thread.PairedRequests.Contains(response.CallId)) + { + throw new InvalidOperationException($"Seeing explicit response to initially paired request with id {response.CallId}"); + } + else + { + throw new InvalidOperationException($"Seeing response to nonexistent request with id {response.CallId}"); + } + } + } + + private sealed class FunctionApprovalStrategy : IRequestResponseStrategy + { + public UserInputResponseContent CreatePairedResponse(UserInputRequestContent request) + { + if (request is not FunctionApprovalRequestContent approvalRequest) + { + throw new InvalidOperationException($"Invalid request: Expecting {typeof(FunctionApprovalResponseContent)}, got {request.GetType()}"); + } + + return new FunctionApprovalResponseContent(approvalRequest.Id, true, approvalRequest.FunctionCall); + } + + public IEnumerable<(string, UserInputRequestContent)> CreateRequests(int count) + { + for (int i = 0; i < count; i++) + { + string id = Guid.NewGuid().ToString("N"); + UserInputRequestContent request = new FunctionApprovalRequestContent(id, new(id, "TestFunction")); + yield return (id, request); + } + } + + public void ProcessResponse(UserInputResponseContent response, TestRequestAgentThread thread) + { + if (thread.UnservicedRequests.TryGetValue(response.Id, out UserInputRequestContent? request)) + { + if (request is not FunctionApprovalRequestContent approvalRequest) + { + throw new InvalidOperationException($"Invalid request: Expecting {typeof(FunctionApprovalResponseContent)}, got {request.GetType()}"); + } + + if (response is not FunctionApprovalResponseContent approvalResponse) + { + throw new InvalidOperationException($"Invalid response: Expecting {typeof(FunctionApprovalResponseContent)}, got {response.GetType()}"); + } + + approvalResponse.Approved.Should().BeTrue(); + approvalResponse.FunctionCall.As().Should().Be(approvalRequest.FunctionCall); + thread.ServicedRequests.Add(response.Id); + thread.UnservicedRequests.Remove(response.Id); + } + else if (thread.ServicedRequests.Contains(response.Id)) + { + throw new InvalidOperationException($"Seeing duplicate response with id {response.Id}"); + } + else if (thread.PairedRequests.Contains(response.Id)) + { + throw new InvalidOperationException($"Seeing explicit response to initially paired request with id {response.Id}"); + } + else + { + throw new InvalidOperationException($"Seeing response to nonexistent request with id {response.Id}"); + } + } + } + + private interface IRequestResponseStrategy + where TRequest : AIContent + where TResponse : AIContent + { + IEnumerable<(string, TRequest)> CreateRequests(int count); + TResponse CreatePairedResponse(TRequest request); + + void ProcessResponse(TResponse response, TestRequestAgentThread thread); + } + + protected override IAsyncEnumerable RunCoreStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + { + return requestType switch + { + TestAgentRequestType.FunctionCall => this.RunStreamingAsync(new FunctionCallStrategy(), messages, thread, options, cancellationToken), + TestAgentRequestType.UserInputRequest => this.RunStreamingAsync(new FunctionApprovalStrategy(), messages, thread, options, cancellationToken), + _ => throw new NotSupportedException($"Unknown AgentRequestType {requestType}"), + }; + } + + private static string RetrieveId(TRequest request) + where TRequest : AIContent + { + return request switch + { + FunctionCallContent functionCall => functionCall.CallId, + UserInputRequestContent userInputRequest => userInputRequest.Id, + _ => throw new NotSupportedException($"Unknown request type {typeof(TRequest)}"), + }; + } + + private IEnumerable ValidateUnpairedRequests(IEnumerable requests, IRequestResponseStrategy strategy) + where TRequest : AIContent + where TResponse : AIContent + { + this.LastThread.Should().NotBeNull(); + TestRequestAgentThread traThread = ConvertThread(this.LastThread); + + requests.Should().HaveCount(traThread.UnservicedRequests.Count); + foreach (TRequest request in requests) + { + string requestId = RetrieveId(request); + traThread.UnservicedRequests.Should().ContainKey(requestId); + yield return strategy.CreatePairedResponse(request); + } + } + + internal IEnumerable ValidateUnpairedRequests(IEnumerable requests) + where TRequest : AIContent + { + switch (requestType) + { + case TestAgentRequestType.FunctionCall: + if (typeof(TRequest) != typeof(FunctionCallContent)) + { + throw new ArgumentException($"Invalid request type: Expected {typeof(FunctionCallContent)}, got {typeof(TRequest)}", nameof(requests)); + } + + return this.ValidateUnpairedRequests((IEnumerable)requests, new FunctionCallStrategy()); + case TestAgentRequestType.UserInputRequest: + if (!typeof(UserInputRequestContent).IsAssignableFrom(typeof(TRequest))) + { + throw new ArgumentException($"Invalid request type: Expected {typeof(UserInputRequestContent)}, got {typeof(TRequest)}", nameof(requests)); + } + + return this.ValidateUnpairedRequests((IEnumerable)requests, new FunctionApprovalStrategy()); + default: + throw new NotSupportedException($"Unknown AgentRequestType {requestType}"); + } + } + + internal IEnumerable ValidateUnpairedRequests(List requests) + { + List responses; + switch (requestType) + { + case TestAgentRequestType.FunctionCall: + responses = this.ValidateUnpairedRequests(requests.Select(AssertAndExtractRequestContent)).ToList(); + break; + case TestAgentRequestType.UserInputRequest: + responses = this.ValidateUnpairedRequests(requests.Select(AssertAndExtractRequestContent)).ToList(); + break; + default: + throw new NotSupportedException($"Unknown AgentRequestType {requestType}"); + } + + return Enumerable.Zip(requests, responses, (ExternalRequest request, object response) => request.CreateResponse(response)); + + static TRequest AssertAndExtractRequestContent(ExternalRequest request) + { + request.DataIs(out TRequest? content).Should().BeTrue(); + return content!; + } + } + + private sealed class TestRequestAgentThread : InMemoryAgentThread + where TRequest : AIContent + where TResponse : AIContent + { + public TestRequestAgentThread() + { + } + + public bool HasSentRequests { get; set; } + public Dictionary UnservicedRequests { get; } = new(); + public HashSet ServicedRequests { get; } = new(); + public HashSet PairedRequests { get; } = new(); + + private static JsonElement DeserializeAndExtractState(JsonElement serializedState, + out TestRequestAgentThreadState state, + JsonSerializerOptions? jsonSerializerOptions = null) + { + state = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) + ?? throw new ArgumentException(""); + + return state.ThreadState; + } + + public TestRequestAgentThread(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) + : base(DeserializeAndExtractState(element, out TestRequestAgentThreadState state, jsonSerializerOptions)) + { + this.UnservicedRequests = state.UnservicedRequests.ToDictionary( + keySelector: item => item.Key, + elementSelector: item => item.Value.As()!); + + this.ServicedRequests = state.ServicedRequests; + this.PairedRequests = state.PairedRequests; + } + + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + JsonElement threadState = base.Serialize(jsonSerializerOptions); + + Dictionary portableUnservicedRequests = + this.UnservicedRequests.ToDictionary( + keySelector: item => item.Key, + elementSelector: item => new PortableValue(item.Value)); + + TestRequestAgentThreadState state = new(threadState, portableUnservicedRequests, this.ServicedRequests, this.PairedRequests); + + return JsonSerializer.SerializeToElement(state, jsonSerializerOptions); + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs index 57375b8341..7535a12c58 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -10,6 +11,36 @@ namespace Microsoft.Agents.AI.Workflows.UnitTests; public class TestRunContext : IRunnerContext { + private sealed class TestExternalRequestContext(IRunnerContext runnerContext, string executorId, EdgeMap? map) : IExternalRequestContext + { + public IExternalRequestSink RegisterPort(RequestPort port) + { + if (map?.TryRegisterPort(runnerContext, executorId, port) == false) + { + throw new InvalidOperationException("Duplicate port id: " + port.Id); + } + + return runnerContext; + } + } + + internal TestRunContext ConfigureExecutor(Executor executor, EdgeMap? map = null) + { + executor.Configure(new TestExternalRequestContext(this, executor.Id, map)); + this.Executors.Add(executor.Id, executor); + return this; + } + + internal TestRunContext ConfigureExecutors(IEnumerable executors, EdgeMap? map = null) + { + foreach (var executor in executors) + { + this.ConfigureExecutor(executor, map); + } + + return this; + } + private sealed class BoundContext( string executorId, TestRunContext runnerContext, @@ -57,13 +88,13 @@ public ValueTask AddEventAsync(WorkflowEvent workflowEvent, CancellationToken ca return default; } - public IWorkflowContext Bind(string executorId, Dictionary? traceContext = null) + public IWorkflowContext BindWorkflowContext(string executorId, Dictionary? traceContext = null) => new BoundContext(executorId, this, traceContext); - public List ExternalRequests { get; } = []; + public ConcurrentQueue ExternalRequests { get; } = []; public ValueTask PostAsync(ExternalRequest request) { - this.ExternalRequests.Add(request); + this.ExternalRequests.Enqueue(request); return default; } @@ -85,8 +116,8 @@ ValueTask IRunnerContext.AdvanceAsync(CancellationToken cancellatio public Dictionary Executors { get; set; } = []; public string StartingExecutorId { get; set; } = string.Empty; - public bool WithCheckpointing => throw new NotSupportedException(); - public bool ConcurrentRunsEnabled => throw new NotSupportedException(); + public bool WithCheckpointing => false; + public bool ConcurrentRunsEnabled => false; ValueTask IRunnerContext.EnsureExecutorAsync(string executorId, IStepTracer? tracer, CancellationToken cancellationToken) => new(this.Executors[executorId]);