Skip to content

Commit b267755

Browse files
authored
feat: Clean up A2AClient (#27)
* feat: Clean up A2AClient - Add cancellation token - Deduplicate SSE methods - Use ResponseHeadersRead to avoid buffering responses - Use shared HttpClient singleton if no client is provided - Tweak naming * Add back missing check
1 parent a7c4c85 commit b267755

File tree

2 files changed

+113
-85
lines changed

2 files changed

+113
-85
lines changed

src/A2A/Client/A2AClient.cs

Lines changed: 106 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,132 +1,160 @@
11
using System.Net.ServerSentEvents;
2+
using System.Runtime.CompilerServices;
23
using System.Text.Json;
34
using System.Text.Json.Serialization.Metadata;
45

56
namespace A2A;
67

7-
public class A2AClient : IA2AClient
8+
public sealed class A2AClient : IA2AClient
89
{
9-
private readonly HttpClient _client;
10+
private static readonly HttpClient s_sharedClient = new();
1011

11-
public A2AClient(HttpClient client)
12+
private readonly HttpClient _httpClient;
13+
14+
public A2AClient(HttpClient? httpClient = null)
1215
{
13-
_client = client;
16+
_httpClient = httpClient ?? s_sharedClient;
1417
}
1518

16-
public Task<A2AResponse> SendMessageAsync(MessageSendParams taskSendParams) =>
17-
RpcRequest(
19+
public Task<A2AResponse> SendMessageAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default) =>
20+
SendRpcRequestAsync(
1821
taskSendParams,
1922
A2AMethods.MessageSend,
2023
A2AJsonUtilities.JsonContext.Default.MessageSendParams,
21-
A2AJsonUtilities.JsonContext.Default.A2AResponse);
24+
A2AJsonUtilities.JsonContext.Default.A2AResponse,
25+
cancellationToken);
2226

23-
public Task<AgentTask> GetTaskAsync(string taskId) =>
24-
RpcRequest(
27+
public Task<AgentTask> GetTaskAsync(string taskId, CancellationToken cancellationToken = default) =>
28+
SendRpcRequestAsync(
2529
new() { Id = taskId },
2630
A2AMethods.TaskGet,
2731
A2AJsonUtilities.JsonContext.Default.TaskIdParams,
28-
A2AJsonUtilities.JsonContext.Default.AgentTask);
32+
A2AJsonUtilities.JsonContext.Default.AgentTask,
33+
cancellationToken);
2934

30-
public Task<AgentTask> CancelTaskAsync(TaskIdParams taskIdParams) =>
31-
RpcRequest(
35+
public Task<AgentTask> CancelTaskAsync(TaskIdParams taskIdParams, CancellationToken cancellationToken = default) =>
36+
SendRpcRequestAsync(
3237
taskIdParams,
3338
A2AMethods.TaskCancel,
3439
A2AJsonUtilities.JsonContext.Default.TaskIdParams,
35-
A2AJsonUtilities.JsonContext.Default.AgentTask);
40+
A2AJsonUtilities.JsonContext.Default.AgentTask,
41+
cancellationToken);
3642

37-
public Task<TaskPushNotificationConfig> SetPushNotificationAsync(TaskPushNotificationConfig pushNotificationConfig) =>
38-
RpcRequest(
43+
public Task<TaskPushNotificationConfig> SetPushNotificationAsync(TaskPushNotificationConfig pushNotificationConfig, CancellationToken cancellationToken = default) =>
44+
SendRpcRequestAsync(
3945
pushNotificationConfig,
4046
"task/pushNotification/set",
4147
A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
42-
A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig);
48+
A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
49+
cancellationToken);
4350

44-
public Task<TaskPushNotificationConfig> GetPushNotificationAsync(TaskIdParams taskIdParams) =>
45-
RpcRequest(
51+
public Task<TaskPushNotificationConfig> GetPushNotificationAsync(TaskIdParams taskIdParams, CancellationToken cancellationToken = default) =>
52+
SendRpcRequestAsync(
4653
taskIdParams,
4754
"task/pushNotification/get",
4855
A2AJsonUtilities.JsonContext.Default.TaskIdParams,
49-
A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig);
56+
A2AJsonUtilities.JsonContext.Default.TaskPushNotificationConfig,
57+
cancellationToken);
58+
59+
public IAsyncEnumerable<SseItem<A2AEvent>> SendMessageStreamAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default) =>
60+
SendRpcSseRequestAsync(
61+
taskSendParams,
62+
A2AMethods.MessageStream,
63+
A2AJsonUtilities.JsonContext.Default.MessageSendParams,
64+
A2AJsonUtilities.JsonContext.Default.A2AEvent,
65+
cancellationToken);
5066

51-
public async IAsyncEnumerable<SseItem<A2AEvent>> SendMessageStreamAsync(MessageSendParams taskSendParams)
67+
public IAsyncEnumerable<SseItem<A2AEvent>> ResubscribeToTaskAsync(string taskId, CancellationToken cancellationToken = default) =>
68+
SendRpcSseRequestAsync(
69+
new() { Id = taskId },
70+
A2AMethods.TaskResubscribe,
71+
A2AJsonUtilities.JsonContext.Default.TaskIdParams,
72+
A2AJsonUtilities.JsonContext.Default.A2AEvent,
73+
cancellationToken);
74+
75+
private async Task<TOutput> SendRpcRequestAsync<TInput, TOutput>(
76+
TInput jsonRpcParams,
77+
string method,
78+
JsonTypeInfo<TInput> inputTypeInfo,
79+
JsonTypeInfo<TOutput> outputTypeInfo,
80+
CancellationToken cancellationToken) where TOutput : class
5281
{
53-
var request = new JsonRpcRequest()
54-
{
55-
Id = Guid.NewGuid().ToString(),
56-
Method = A2AMethods.MessageStream,
57-
Params = JsonSerializer.SerializeToElement(taskSendParams, A2AJsonUtilities.JsonContext.Default.MessageSendParams),
58-
};
59-
var response = await _client.SendAsync(new HttpRequestMessage(HttpMethod.Post, "")
60-
{
61-
Content = new JsonRpcContent(request)
62-
});
63-
response.EnsureSuccessStatusCode();
64-
using var stream = await response.Content.ReadAsStreamAsync();
65-
var sseParser = SseParser.Create(stream, (eventType, data) =>
66-
{
67-
var reader = new Utf8JsonReader(data);
68-
return JsonSerializer.Deserialize(ref reader, A2AJsonUtilities.JsonContext.Default.A2AEvent) ?? throw new InvalidOperationException("Failed to deserialize the event.");
69-
});
70-
await foreach (var item in sseParser.EnumerateAsync())
71-
{
72-
yield return item;
73-
}
82+
using var responseStream = await SendAndReadResponseStream(
83+
jsonRpcParams,
84+
method,
85+
inputTypeInfo,
86+
"application/json",
87+
cancellationToken).ConfigureAwait(false);
88+
89+
var responseObject = await JsonSerializer.DeserializeAsync(responseStream, A2AJsonUtilities.JsonContext.Default.JsonRpcResponse, cancellationToken) ??
90+
throw new InvalidOperationException("Failed to deserialize the response.");
91+
92+
return responseObject.Result?.Deserialize(outputTypeInfo) ??
93+
throw new InvalidOperationException("Response does not contain a result.");
7494
}
7595

76-
public async IAsyncEnumerable<SseItem<A2AEvent>> ResubscribeToTaskAsync(string taskId)
96+
private async IAsyncEnumerable<SseItem<TOutput>> SendRpcSseRequestAsync<TInput, TOutput>(
97+
TInput jsonRpcParams,
98+
string method,
99+
JsonTypeInfo<TInput> inputTypeInfo,
100+
JsonTypeInfo<TOutput> outputTypeInfo,
101+
[EnumeratorCancellation] CancellationToken cancellationToken)
77102
{
78-
var request = new JsonRpcRequest()
79-
{
80-
Id = Guid.NewGuid().ToString(),
81-
Method = A2AMethods.TaskResubscribe,
82-
Params = JsonSerializer.SerializeToElement(new TaskIdParams() { Id = taskId }, A2AJsonUtilities.JsonContext.Default.TaskIdParams),
83-
};
84-
var response = await _client.SendAsync(new HttpRequestMessage(HttpMethod.Post, "")
85-
{
86-
Content = new JsonRpcContent(request)
87-
});
88-
response.EnsureSuccessStatusCode();
89-
using var stream = await response.Content.ReadAsStreamAsync();
90-
var sseParser = SseParser.Create(stream, (eventType, data) =>
103+
using var responseStream = await SendAndReadResponseStream(
104+
jsonRpcParams,
105+
method,
106+
inputTypeInfo,
107+
"text/event-stream",
108+
cancellationToken).ConfigureAwait(false);
109+
110+
var sseParser = SseParser.Create(responseStream, (eventType, data) =>
91111
{
92112
var reader = new Utf8JsonReader(data);
93-
return JsonSerializer.Deserialize(ref reader, A2AJsonUtilities.JsonContext.Default.A2AEvent) ?? throw new InvalidOperationException("Failed to deserialize the event.");
113+
return JsonSerializer.Deserialize(ref reader, outputTypeInfo) ?? throw new InvalidOperationException("Failed to deserialize the event.");
94114
});
95-
await foreach (var item in sseParser.EnumerateAsync())
115+
116+
await foreach (var item in sseParser.EnumerateAsync(cancellationToken))
96117
{
97118
yield return item;
98119
}
99120
}
100121

101-
private async Task<TOutput> RpcRequest<TInput, TOutput>(
122+
private async ValueTask<Stream> SendAndReadResponseStream<TInput>(
102123
TInput jsonRpcParams,
103124
string method,
104125
JsonTypeInfo<TInput> inputTypeInfo,
105-
JsonTypeInfo<TOutput> outputTypeInfo) where TOutput : class
126+
string expectedContentType,
127+
CancellationToken cancellationToken)
106128
{
107-
var request = new JsonRpcRequest()
129+
var response = await _httpClient.SendAsync(new(HttpMethod.Post, "")
108130
{
109-
Id = Guid.NewGuid().ToString(),
110-
Method = method,
111-
Params = JsonSerializer.SerializeToElement(jsonRpcParams, inputTypeInfo),
112-
};
131+
Content = new JsonRpcContent(new JsonRpcRequest()
132+
{
133+
Id = Guid.NewGuid().ToString(),
134+
Method = method,
135+
Params = JsonSerializer.SerializeToElement(jsonRpcParams, inputTypeInfo),
136+
})
137+
}, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
113138

114-
using var response = await _client.SendAsync(new HttpRequestMessage(HttpMethod.Post, "")
115-
{
116-
Content = new JsonRpcContent(request)
117-
});
118-
response.EnsureSuccessStatusCode();
119-
if (response.Content.Headers.ContentType?.MediaType != "application/json")
139+
try
120140
{
121-
throw new InvalidOperationException("Invalid content type.");
122-
}
123-
124-
using var responseStream = await response.Content.ReadAsStreamAsync();
141+
response.EnsureSuccessStatusCode();
125142

126-
var responseObject = await JsonSerializer.DeserializeAsync(responseStream, A2AJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
127-
throw new InvalidOperationException("Failed to deserialize the response.");
143+
if (response.Content.Headers.ContentType?.MediaType != expectedContentType)
144+
{
145+
throw new InvalidOperationException($"Invalid content type. Expected '{expectedContentType}' but got '{response.Content.Headers.ContentType?.MediaType}'.");
146+
}
128147

129-
return responseObject.Result?.Deserialize(outputTypeInfo) ??
130-
throw new InvalidOperationException("Response does not contain a result.");
148+
return await response.Content.ReadAsStreamAsync(
149+
#if NET
150+
cancellationToken
151+
#endif
152+
);
153+
}
154+
catch
155+
{
156+
response.Dispose();
157+
throw;
158+
}
131159
}
132160
}

src/A2A/Client/IA2AClient.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ namespace A2A;
44

55
public interface IA2AClient
66
{
7-
Task<A2AResponse> SendMessageAsync(MessageSendParams taskSendParams);
8-
Task<AgentTask> GetTaskAsync(string taskId);
9-
Task<AgentTask> CancelTaskAsync(TaskIdParams taskIdParams);
10-
IAsyncEnumerable<SseItem<A2AEvent>> SendMessageStreamAsync(MessageSendParams taskSendParams);
11-
IAsyncEnumerable<SseItem<A2AEvent>> ResubscribeToTaskAsync(string taskId);
12-
Task<TaskPushNotificationConfig> SetPushNotificationAsync(TaskPushNotificationConfig pushNotificationConfig);
13-
Task<TaskPushNotificationConfig> GetPushNotificationAsync(TaskIdParams taskIdParams);
7+
Task<A2AResponse> SendMessageAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default);
8+
Task<AgentTask> GetTaskAsync(string taskId, CancellationToken cancellationToken = default);
9+
Task<AgentTask> CancelTaskAsync(TaskIdParams taskIdParams, CancellationToken cancellationToken = default);
10+
IAsyncEnumerable<SseItem<A2AEvent>> SendMessageStreamAsync(MessageSendParams taskSendParams, CancellationToken cancellationToken = default);
11+
IAsyncEnumerable<SseItem<A2AEvent>> ResubscribeToTaskAsync(string taskId, CancellationToken cancellationToken = default);
12+
Task<TaskPushNotificationConfig> SetPushNotificationAsync(TaskPushNotificationConfig pushNotificationConfig, CancellationToken cancellationToken = default);
13+
Task<TaskPushNotificationConfig> GetPushNotificationAsync(TaskIdParams taskIdParams, CancellationToken cancellationToken = default);
1414
}

0 commit comments

Comments
 (0)