diff --git a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs index 7b714e244578..78eb02795d62 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.net8.0.cs @@ -1597,7 +1597,7 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden public static partial class PersistentAgentsClientExtensions { public static Microsoft.Extensions.AI.AITool AsAITool(this Azure.AI.Agents.Persistent.ToolDefinition tool) { throw null; } - public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; } + public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) { throw null; } } public static partial class PersistentAgentsExtensions { diff --git a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs index 0eac61d23925..9118137f9f26 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/api/Azure.AI.Agents.Persistent.netstandard2.0.cs @@ -1597,7 +1597,7 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden public static partial class PersistentAgentsClientExtensions { public static Microsoft.Extensions.AI.AITool AsAITool(this Azure.AI.Agents.Persistent.ToolDefinition tool) { throw null; } - public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; } + public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) { throw null; } } public static partial class PersistentAgentsExtensions { diff --git a/sdk/ai/Azure.AI.Agents.Persistent/assets.json b/sdk/ai/Azure.AI.Agents.Persistent/assets.json index 7756f0f2a102..1339f0cca339 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/assets.json +++ b/sdk/ai/Azure.AI.Agents.Persistent/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "net", "TagPrefix": "net/ai/Azure.AI.Agents.Persistent", - "Tag": "net/ai/Azure.AI.Agents.Persistent_84020b2662" + "Tag": "net/ai/Azure.AI.Agents.Persistent_89f0bef6e6" } diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs index 987d48ae2506..af2b86e30842 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsChatClient.cs @@ -40,8 +40,13 @@ internal partial class PersistentAgentsChatClient : IChatClient /// Lazily-retrieved agent instance. Used for its properties. private PersistentAgent? _agent; + /// + /// Indicates whether to throw exceptions when content errors are encountered. + /// + private readonly bool _throwOnContentErrors; + /// Initializes a new instance of the class for the specified . - public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null) + public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) { Argument.AssertNotNull(client, nameof(client)); Argument.AssertNotNullOrWhiteSpace(agentId, nameof(agentId)); @@ -51,6 +56,7 @@ public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, _defaultThreadId = defaultThreadId; _metadata = new(ProviderName); + _throwOnContentErrors = throwOnContentErrors; } protected PersistentAgentsChatClient() { } @@ -191,6 +197,14 @@ threadRun is not null && switch (ru) { + case RunUpdate rup when rup.Value.Status == RunStatus.Failed && rup.Value.LastError is { } error: + if (_throwOnContentErrors) + { + throw new InvalidOperationException(error.Message) { Data = { ["ErrorCode"] = error.Code } }; + } + ruUpdate.Contents.Add(new ErrorContent(error.Message) { ErrorCode = error.Code, RawRepresentation = error }); + break; + case RequiredActionUpdate rau when rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName: ruUpdate.Contents.Add(new FunctionCallContent( JsonSerializer.Serialize([ru.Value.Id, toolCallId], AgentsChatClientJsonContext.Default.StringArray), diff --git a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs index 83ca6d0bc96a..9c248dfbbcd6 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/src/Custom/PersistentAgentsClientExtensions.cs @@ -22,9 +22,10 @@ public static class PersistentAgentsClientExtensions /// or via the /// property. If no thread ID is provided via either mechanism, a new thread will be created for the request. /// + /// Throws an exception if content errors are returned from the service. Default is true. This is useful to detect errors when tools are misconfigured that otherwise would be unnoticed because those come as a streaming data update. /// An instance configured to interact with the specified agent and thread. - public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null) => - new PersistentAgentsChatClient(client, agentId, defaultThreadId); + public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) => + new PersistentAgentsChatClient(client, agentId, defaultThreadId, throwOnContentErrors); /// Creates an to represent a . /// The tool to wrap as an . diff --git a/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs index 5b5ccbc1099e..479b479e3c05 100644 --- a/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs +++ b/sdk/ai/Azure.AI.Agents.Persistent/tests/PersistentAgentsChatClientTests.cs @@ -8,6 +8,8 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Pipeline; using Azure.Core.TestFramework; using Azure.Identity; using Microsoft.Extensions.AI; @@ -26,6 +28,11 @@ public class PersistentAgentsChatClientTests : RecordedTestBase(() => chatClient.GetService(null)); } + [RecordedTest] + public async Task TestContentErrorHandling() + { + var mockTransport = new MockTransport((request) => + { + return GetResponse(request, emptyRunList: false); + }); + + PersistentAgentsClient client = GetClient(mockTransport); + + IChatClient throwingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: true); + IChatClient nonThrowingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: false); + + var exception = Assert.ThrowsAsync(() => throwingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word"))); + Assert.IsTrue(exception.Message.Contains("wrong-connection-id")); + + var response = await nonThrowingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word")); + var errorContent = response.Messages.SelectMany(m => m.Contents).OfType().Single(); + Assert.IsTrue(errorContent.Message.Contains("wrong-connection-id")); + } + #region Helpers + + private PersistentAgentsClient GetClient(HttpPipelineTransport transport) + { + return new(FakeAgentEndpoint, new MockCredential(), options: new PersistentAgentsAdministrationClientOptions() + { + Transport = transport + }); + } + private class CompositeDisposable : IDisposable { private readonly List _disposables = []; @@ -471,5 +508,81 @@ private static string GetFile([CallerFilePath] string pth = "", string fileName var dirName = Path.GetDirectoryName(pth) ?? ""; return Path.Combine(new string[] { dirName, "TestData", fileName }); } + + private static MockResponse GetResponse(MockRequest request, bool? emptyRunList = true) + { + // Sent by client.Administration.GetAgentAsync(...) method + if (request.Method == RequestMethod.Get && request.Uri.Path == $"/assistants/{FakeAgentId}") + { + return CreateOKMockResponse($$""" + { + "id": "{{FakeAgentId}}" + } + """); + } + // Sent by client.Runs.GetRunsAsync(...) method + else if (request.Method == RequestMethod.Get && request.Uri.Path == $"/threads/{FakeThreadId}/runs") + { + return CreateOKMockResponse($$""" + { + "data": {{(emptyRunList is true + ? "[]" + : $$"""[{"id": "{{FakeRunId}}"}]""")}} + } + """); + } + // Sent by client.Runs.CreateRunStreamingAsync(...) method + else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs") + { + // Content failure response + return CreateOKMockResponse( + $$$""" + event: thread.run.failed + data: { "id":"{{{FakeRunId}}}","object":"thread.run","created_at":1764170243,"assistant_id":"asst_uYPWW0weSBNqXK3VjgRMkuim","thread_id":"thread_dmz0AZPJtnO9MnAfrzP1AtJ6","status":"failed","started_at":1764170244,"expires_at":null,"cancelled_at":null,"failed_at":1764170244,"completed_at":null,"required_action":null,"last_error":{ "code":"tool_user_error","message":"Error: invalid_tool_input; The specified connection ID 'wrong-connection-id' was not found in the project or account connections. Please verify that the connection id in tool input is correct and exists in the project or account."},"model":"gpt-4o","instructions":"Use the bing grounding tool to answer questions.Use the bing grounding tool to answer questions.","tools":[{ "type":"bing_grounding","bing_grounding":{ "search_configurations":[{ "connection_id":"wrong-connection-id","market":"en-US","set_lang":"en","count":5}]} }],"tool_resources":{ "code_interpreter":{ "file_ids":[]} },"metadata":{ },"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{ "type":"auto","last_messages":null},"incomplete_details":null,"usage":{ "prompt_tokens":0,"completion_tokens":0,"total_tokens":0,"prompt_token_details":{ "cached_tokens":0} },"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: done + data: [DONE] + """ + ); + } + // Sent by client.Threads.CreateThreadAsync(...) method + else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads") + { + return CreateOKMockResponse($$""" + { + "id": "{{FakeThreadId}}" + } + """); + } + // Sent by client.Runs.CancelRunAsync(...) method + else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs/{FakeRunId}/cancel") + { + return CreateOKMockResponse($$""" + { + "id": "{{FakeThreadId}}" + } + """); + } + // Sent by client.Runs.SubmitToolOutputsToStreamAsync(...) method + else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads//runs/{FakeRunId}/submit_tool_outputs") + { + return CreateOKMockResponse($$""" + { + "data":[{ + "id": "{{FakeRunId}}" + }] + } + """); + } + + throw new InvalidOperationException("Unexpected request"); + } + + private static MockResponse CreateOKMockResponse(string content) + { + var response = new MockResponse(200); + response.SetContent(content); + return response; + } } }