diff --git a/dotnet/src/Microsoft.Agents.AI/AgentExtensions.cs b/dotnet/src/Microsoft.Agents.AI/AgentExtensions.cs index 097b789a84..07247b059d 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentExtensions.cs @@ -73,7 +73,12 @@ async Task InvokeAgentAsync( [Description("Input query to invoke the agent.")] string query, CancellationToken cancellationToken) { - var response = await agent.RunAsync(query, thread: thread, cancellationToken: cancellationToken).ConfigureAwait(false); + // Propagate any additional properties from the parent agent's run to the child agent if the parent is using a FunctionInvokingChatClient. + AgentRunOptions? agentRunOptions = FunctionInvokingChatClient.CurrentContext?.Options?.AdditionalProperties is AdditionalPropertiesDictionary dict + ? new AgentRunOptions { AdditionalProperties = dict } + : null; + + var response = await agent.RunAsync(query, thread: thread, options: agentRunOptions, cancellationToken: cancellationToken).ConfigureAwait(false); return response.Text; } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs index 43039a7b76..e0c5417674 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs @@ -277,6 +277,48 @@ public async Task CreateFromAgent_InvokeWithComplexResponseFromAgentAsync_Return Assert.Equal("Complex response", result.ToString()); } + [Fact] + public async Task CreateFromAgent_InvokeWithAdditionalProperties_PropagatesAdditionalPropertiesToChildAgentAsync() + { + // Arrange + var expectedResponse = new AgentResponse + { + AgentId = "agent-123", + ResponseId = "response-456", + CreatedAt = DateTimeOffset.UtcNow, + Messages = { new ChatMessage(ChatRole.Assistant, "Complex response") } + }; + + var testAgent = new TestAgent("TestAgent", "Test description", expectedResponse); + var aiFunction = testAgent.AsAIFunction(); + + // Use reflection to set the protected CurrentContext property + var context = new FunctionInvocationContext() + { + Options = new() + { + AdditionalProperties = new AdditionalPropertiesDictionary + { + { "customProperty1", "value1" }, + { "customProperty2", 42 } + } + } + }; + SetFunctionInvokingChatClientCurrentContext(context); + + // Act + var arguments = new AIFunctionArguments() { ["query"] = "Test query" }; + var result = await aiFunction.InvokeAsync(arguments); + + // Assert + Assert.NotNull(result); + Assert.Equal("Complex response", result.ToString()); + Assert.NotNull(testAgent.ReceivedAgentRunOptions); + Assert.NotNull(testAgent.ReceivedAgentRunOptions!.AdditionalProperties); + Assert.Equal("value1", testAgent.ReceivedAgentRunOptions!.AdditionalProperties["customProperty1"]); + Assert.Equal(42, testAgent.ReceivedAgentRunOptions!.AdditionalProperties["customProperty2"]); + } + [Theory] [InlineData("MyAgent", "MyAgent")] [InlineData("Agent123", "Agent123")] @@ -302,6 +344,22 @@ public void CreateFromAgent_SanitizesAgentName(string agentName, string expected Assert.Equal(expectedFunctionName, result.Name); } + /// + /// Uses reflection to set the protected static CurrentContext property on FunctionInvokingChatClient. + /// + private static void SetFunctionInvokingChatClientCurrentContext(FunctionInvocationContext? context) + { + // Access the private static field _currentContext which is an AsyncLocal + var currentContextField = typeof(FunctionInvokingChatClient).GetField( + "_currentContext", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + + if (currentContextField?.GetValue(null) is AsyncLocal asyncLocal) + { + asyncLocal.Value = context; + } + } + /// /// Test implementation of AIAgent for testing purposes. /// @@ -334,6 +392,7 @@ public override ValueTask DeserializeThreadAsync(JsonElement serial public override string? Description { get; } public List ReceivedMessages { get; } = []; + public AgentRunOptions? ReceivedAgentRunOptions { get; private set; } public CancellationToken LastCancellationToken { get; private set; } public int RunAsyncCallCount { get; private set; } @@ -346,6 +405,7 @@ protected override Task RunCoreAsync( this.RunAsyncCallCount++; this.LastCancellationToken = cancellationToken; this.ReceivedMessages.AddRange(messages); + this.ReceivedAgentRunOptions = options; if (this._exceptionToThrow is not null) {