-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathA2AClient.cs
More file actions
202 lines (174 loc) · 8.19 KB
/
A2AClient.cs
File metadata and controls
202 lines (174 loc) · 8.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
using System.Net.ServerSentEvents;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
namespace A2A;
/// <summary>
/// Implementation of A2A client for communicating with agents.
/// </summary>
/// <remarks>
/// Every operation is sent as a JSON-RPC request (HTTP POST) to the base URL supplied at construction.
/// Unary operations expect an <c>application/json</c> response; streaming operations expect
/// <c>text/event-stream</c>, where each event's data is a JSON-RPC response envelope.
/// </remarks>
public sealed class A2AClient : IA2AClient
{
    /// <summary>
    /// Client shared by every <see cref="A2AClient"/> instance that was not given an explicit <see cref="HttpClient"/>.
    /// </summary>
    internal static readonly HttpClient s_sharedClient = new();

    /// <summary>The HTTP client used to send all requests.</summary>
    private readonly HttpClient _httpClient;

    /// <summary>The endpoint every JSON-RPC request is POSTed to.</summary>
    private readonly Uri _baseUri;

    /// <summary>
    /// Initializes a new instance of <see cref="A2AClient"/>.
    /// </summary>
    /// <param name="baseUrl">The base url of the agent's hosting service.</param>
    /// <param name="httpClient">The HTTP client to use for requests. When <see langword="null"/>, a shared client is used.</param>
    /// <exception cref="ArgumentNullException"><paramref name="baseUrl"/> is <see langword="null"/>.</exception>
    public A2AClient(Uri baseUrl, HttpClient? httpClient = null)
    {
        if (baseUrl is null)
        {
            throw new ArgumentNullException(nameof(baseUrl), "Base URL cannot be null.");
        }

        _baseUri = baseUrl;
        _httpClient = httpClient ?? s_sharedClient;
    }

    /// <inheritdoc />
    public Task<A2AResponse> SendMessageAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default) =>
        SendRpcRequestAsync(
            taskSendParams ?? throw new ArgumentNullException(nameof(taskSendParams)),
            A2AMethods.MessageSend,
            A2AJsonUtilities.JsonContext.Default.MessageSendParams,
            A2AJsonUtilities.JsonContext.Default.A2AResponse,
            cancellationToken);

    /// <inheritdoc />
    public Task<AgentTask> GetTaskAsync(string taskId, CancellationToken cancellationToken = default) =>
        SendRpcRequestAsync(
            new() { Id = string.IsNullOrEmpty(taskId) ? throw new ArgumentNullException(nameof(taskId)) : taskId },
            A2AMethods.TaskGet,
            A2AJsonUtilities.JsonContext.Default.TaskIdParams,
            A2AJsonUtilities.JsonContext.Default.AgentTask,
            cancellationToken);

    /// <inheritdoc />
    public Task<AgentTask> CancelTaskAsync(TaskIdParams taskIdParams, CancellationToken cancellationToken = default) =>
        SendRpcRequestAsync(
            taskIdParams ?? throw new ArgumentNullException(nameof(taskIdParams)),
            A2AMethods.TaskCancel,
            A2AJsonUtilities.JsonContext.Default.TaskIdParams,
            A2AJsonUtilities.JsonContext.Default.AgentTask,
            cancellationToken);

    /// <inheritdoc />
    public Task<TaskPushNotificationConfig> SetPushNotificationAsync(TaskPushNotificationConfig pushNotificationConfig, CancellationToken cancellationToken = default) =>
        SendRpcRequestAsync(
            pushNotificationConfig ?? throw new ArgumentNullException(nameof(pushNotificationConfig)),
            A2AMethods.TaskPushNotificationConfigSet,
            A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
            A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
            cancellationToken);

    /// <inheritdoc />
    public Task<TaskPushNotificationConfig> GetPushNotificationAsync(GetTaskPushNotificationConfigParams notificationConfigParams, CancellationToken cancellationToken = default) =>
        SendRpcRequestAsync(
            notificationConfigParams ?? throw new ArgumentNullException(nameof(notificationConfigParams)),
            A2AMethods.TaskPushNotificationConfigGet,
            A2AJsonUtilities.JsonContext.Default.GetTaskPushNotificationConfigParams,
            A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
            cancellationToken);

    /// <inheritdoc />
    public IAsyncEnumerable<SseItem<A2AEvent>> SendMessageStreamAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default) =>
        SendRpcSseRequestAsync(
            taskSendParams ?? throw new ArgumentNullException(nameof(taskSendParams)),
            A2AMethods.MessageStream,
            A2AJsonUtilities.JsonContext.Default.MessageSendParams,
            A2AJsonUtilities.JsonContext.Default.A2AEvent,
            cancellationToken);

    /// <inheritdoc />
    public IAsyncEnumerable<SseItem<A2AEvent>> SubscribeToTaskAsync(string taskId, CancellationToken cancellationToken = default) =>
        SendRpcSseRequestAsync(
            new() { Id = string.IsNullOrEmpty(taskId) ? throw new ArgumentNullException(nameof(taskId)) : taskId },
            A2AMethods.TaskSubscribe,
            A2AJsonUtilities.JsonContext.Default.TaskIdParams,
            A2AJsonUtilities.JsonContext.Default.A2AEvent,
            cancellationToken);

    /// <summary>
    /// Sends a unary JSON-RPC request and deserializes the <c>result</c> member of the response.
    /// </summary>
    /// <typeparam name="TInput">Type of the JSON-RPC <c>params</c> payload.</typeparam>
    /// <typeparam name="TOutput">Type the <c>result</c> member is deserialized into.</typeparam>
    /// <param name="jsonRpcParams">The request parameters.</param>
    /// <param name="method">The JSON-RPC method name.</param>
    /// <param name="inputTypeInfo">Serialization metadata for <typeparamref name="TInput"/>.</param>
    /// <param name="outputTypeInfo">Serialization metadata for <typeparamref name="TOutput"/>.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>The deserialized result.</returns>
    /// <exception cref="A2AException">The response carries a JSON-RPC <c>error</c> object.</exception>
    /// <exception cref="InvalidOperationException">The response has no result, or the content type is not <c>application/json</c>.</exception>
    private async Task<TOutput> SendRpcRequestAsync<TInput, TOutput>(
        TInput jsonRpcParams,
        string method,
        JsonTypeInfo<TInput> inputTypeInfo,
        JsonTypeInfo<TOutput> outputTypeInfo,
        CancellationToken cancellationToken) where TOutput : class
    {
        cancellationToken.ThrowIfCancellationRequested();

        using var responseStream = await SendAndReadResponseStreamAsync(
            jsonRpcParams,
            method,
            inputTypeInfo,
            "application/json",
            cancellationToken).ConfigureAwait(false);

        var responseObject = await JsonSerializer.DeserializeAsync(responseStream, A2AJsonUtilities.JsonContext.Default.JsonRpcResponse, cancellationToken).ConfigureAwait(false);

        if (responseObject?.Error is { } error)
        {
            throw new A2AException(error.Message, (A2AErrorCode)error.Code);
        }

        return responseObject?.Result?.Deserialize(outputTypeInfo) ??
            throw new InvalidOperationException("Response does not contain a result.");
    }

    /// <summary>
    /// Sends a JSON-RPC request whose response is a server-sent event stream, yielding each event's
    /// <c>result</c> deserialized as <typeparamref name="TOutput"/>.
    /// </summary>
    /// <typeparam name="TInput">Type of the JSON-RPC <c>params</c> payload.</typeparam>
    /// <typeparam name="TOutput">Type each event's <c>result</c> member is deserialized into.</typeparam>
    /// <param name="jsonRpcParams">The request parameters.</param>
    /// <param name="method">The JSON-RPC method name.</param>
    /// <param name="inputTypeInfo">Serialization metadata for <typeparamref name="TInput"/>.</param>
    /// <param name="outputTypeInfo">Serialization metadata for <typeparamref name="TOutput"/>.</param>
    /// <param name="cancellationToken">Token to cancel the request and the enumeration.</param>
    /// <returns>The parsed events, in arrival order.</returns>
    /// <exception cref="A2AException">An event carries a JSON-RPC <c>error</c> object (thrown during enumeration).</exception>
    /// <exception cref="InvalidOperationException">An event has no result, or the content type is not <c>text/event-stream</c>.</exception>
    private async IAsyncEnumerable<SseItem<TOutput>> SendRpcSseRequestAsync<TInput, TOutput>(
        TInput jsonRpcParams,
        string method,
        JsonTypeInfo<TInput> inputTypeInfo,
        JsonTypeInfo<TOutput> outputTypeInfo,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        using var responseStream = await SendAndReadResponseStreamAsync(
            jsonRpcParams,
            method,
            inputTypeInfo,
            "text/event-stream",
            cancellationToken).ConfigureAwait(false);

        // Each SSE data payload is a full JSON-RPC response envelope; unwrap it here so callers only see results.
        var sseParser = SseParser.Create(responseStream, (_, data) =>
        {
            var reader = new Utf8JsonReader(data);
            var responseObject = JsonSerializer.Deserialize(ref reader, A2AJsonUtilities.JsonContext.Default.JsonRpcResponse);

            if (responseObject?.Error is { } error)
            {
                throw new A2AException(error.Message, (A2AErrorCode)error.Code);
            }

            if (responseObject?.Result is null)
            {
                throw new InvalidOperationException("Failed to deserialize the event: Result is null.");
            }

            return responseObject.Result.Deserialize(outputTypeInfo) ??
                throw new InvalidOperationException("Failed to deserialize the event.");
        });

        // Library code: don't resume on the caller's synchronization context (matches every other await in this class).
        await foreach (var item in sseParser.EnumerateAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return item;
        }
    }

    /// <summary>
    /// POSTs a JSON-RPC request to the base URL and returns the response body stream once the headers
    /// have been received and validated.
    /// </summary>
    /// <typeparam name="TInput">Type of the JSON-RPC <c>params</c> payload.</typeparam>
    /// <param name="jsonRpcParams">The request parameters.</param>
    /// <param name="method">The JSON-RPC method name.</param>
    /// <param name="inputTypeInfo">Serialization metadata for <typeparamref name="TInput"/>.</param>
    /// <param name="expectedContentType">The media type the response must declare.</param>
    /// <param name="cancellationToken">Token to cancel the operation.</param>
    /// <returns>The response content stream; the caller owns and must dispose it.</returns>
    /// <exception cref="HttpRequestException">The response status code does not indicate success.</exception>
    /// <exception cref="InvalidOperationException">The response media type does not match <paramref name="expectedContentType"/>.</exception>
    private async ValueTask<Stream> SendAndReadResponseStreamAsync<TInput>(
        TInput jsonRpcParams,
        string method,
        JsonTypeInfo<TInput> inputTypeInfo,
        string expectedContentType,
        CancellationToken cancellationToken)
    {
        // ResponseHeadersRead lets streaming (SSE) responses be consumed incrementally instead of buffered.
        var response = await _httpClient.SendAsync(new(HttpMethod.Post, _baseUri)
        {
            Content = new JsonRpcContent(new JsonRpcRequest()
            {
                Id = Guid.NewGuid().ToString(),
                Method = method,
                Params = JsonSerializer.SerializeToElement(jsonRpcParams, inputTypeInfo),
            })
        }, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

        try
        {
            response.EnsureSuccessStatusCode();

            // Media type names are case-insensitive (RFC 9110 §8.3.1), so compare ordinally ignoring case;
            // MediaType already excludes parameters such as "; charset=utf-8".
            string? mediaType = response.Content.Headers.ContentType?.MediaType;
            if (!string.Equals(mediaType, expectedContentType, StringComparison.OrdinalIgnoreCase))
            {
                throw new InvalidOperationException($"Invalid content type. Expected '{expectedContentType}' but got '{mediaType}'.");
            }

            return await response.Content.ReadAsStreamAsync(
#if NET
                cancellationToken
#endif
                ).ConfigureAwait(false);
        }
        catch
        {
            // On any failure the caller never receives the stream, so release the response here.
            response.Dispose();
            throw;
        }
    }
}