Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dotnet/src/Microsoft.Agents.AI/AgentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ async Task<string> InvokeAgentAsync(
[Description("Input query to invoke the agent.")] string query,
CancellationToken cancellationToken)
{
var response = await agent.RunAsync(query, thread: thread, cancellationToken: cancellationToken).ConfigureAwait(false);
// Propagate any additional properties from the parent agent's run to the child agent if the parent is using a FunctionInvokingChatClient.
AgentRunOptions? agentRunOptions = FunctionInvokingChatClient.CurrentContext?.Options?.AdditionalProperties is AdditionalPropertiesDictionary dict
? new AgentRunOptions { AdditionalProperties = dict }
: null;

var response = await agent.RunAsync(query, thread: thread, options: agentRunOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
return response.Text;
}

Expand Down
60 changes: 60 additions & 0 deletions dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,48 @@ public async Task CreateFromAgent_InvokeWithComplexResponseFromAgentAsync_Return
Assert.Equal("Complex response", result.ToString());
}

[Fact]
public async Task CreateFromAgent_InvokeWithAdditionalProperties_PropagatesAdditionalPropertiesToChildAgentAsync()
{
// Arrange
var expectedResponse = new AgentResponse
{
AgentId = "agent-123",
ResponseId = "response-456",
CreatedAt = DateTimeOffset.UtcNow,
Messages = { new ChatMessage(ChatRole.Assistant, "Complex response") }
};

var testAgent = new TestAgent("TestAgent", "Test description", expectedResponse);
var aiFunction = testAgent.AsAIFunction();

// Use reflection to set the protected CurrentContext property
var context = new FunctionInvocationContext()
{
Options = new()
{
AdditionalProperties = new AdditionalPropertiesDictionary
{
{ "customProperty1", "value1" },
{ "customProperty2", 42 }
}
}
};
SetFunctionInvokingChatClientCurrentContext(context);

// Act
var arguments = new AIFunctionArguments() { ["query"] = "Test query" };
var result = await aiFunction.InvokeAsync(arguments);

// Assert
Assert.NotNull(result);
Assert.Equal("Complex response", result.ToString());
Assert.NotNull(testAgent.ReceivedAgentRunOptions);
Assert.NotNull(testAgent.ReceivedAgentRunOptions!.AdditionalProperties);
Assert.Equal("value1", testAgent.ReceivedAgentRunOptions!.AdditionalProperties["customProperty1"]);
Assert.Equal(42, testAgent.ReceivedAgentRunOptions!.AdditionalProperties["customProperty2"]);
}

[Theory]
[InlineData("MyAgent", "MyAgent")]
[InlineData("Agent123", "Agent123")]
Expand All @@ -302,6 +344,22 @@ public void CreateFromAgent_SanitizesAgentName(string agentName, string expected
Assert.Equal(expectedFunctionName, result.Name);
}

/// <summary>
/// Uses reflection to set the protected static CurrentContext property on FunctionInvokingChatClient.
/// </summary>
private static void SetFunctionInvokingChatClientCurrentContext(FunctionInvocationContext? context)
{
// Access the private static field _currentContext which is an AsyncLocal<FunctionInvocationContext?>
var currentContextField = typeof(FunctionInvokingChatClient).GetField(
"_currentContext",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

if (currentContextField?.GetValue(null) is AsyncLocal<FunctionInvocationContext?> asyncLocal)
{
asyncLocal.Value = context;
}
}

/// <summary>
/// Test implementation of AIAgent for testing purposes.
/// </summary>
Expand Down Expand Up @@ -334,6 +392,7 @@ public override ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serial
public override string? Description { get; }

public List<ChatMessage> ReceivedMessages { get; } = [];
public AgentRunOptions? ReceivedAgentRunOptions { get; private set; }
public CancellationToken LastCancellationToken { get; private set; }
public int RunAsyncCallCount { get; private set; }

Expand All @@ -346,6 +405,7 @@ protected override Task<AgentResponse> RunCoreAsync(
this.RunAsyncCallCount++;
this.LastCancellationToken = cancellationToken;
this.ReceivedMessages.AddRange(messages);
this.ReceivedAgentRunOptions = options;

if (this._exceptionToThrow is not null)
{
Expand Down
Loading