Skip to content

Commit f7aa465

Browse files
authored
feat: blocking flag handling (a2aproject#97)
* Add `Polling` option to `a2aclient.Config`. If set to true `Blocking: false` message configuration is sent to the server. * Add `Blocking: false` handling on `a2asrv`. When execution results in a Task, the Task is returned to the caller immediately. * Make `Blocking` optional to match the "default to true if not set" behavior all other SDKs implement. Considered returning immediately when a requests references an existing Task, but decided that'd be confusing to always receive the task in `input-required` immediately after providing a follow-up. re a2aproject#96
1 parent 25b9aae commit f7aa465

File tree

11 files changed

+292
-52
lines changed

11 files changed

+292
-52
lines changed

a2a/core.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,8 @@ type MessageSendConfig struct {
691691
AcceptedOutputModes []string `json:"acceptedOutputModes,omitempty" yaml:"acceptedOutputModes,omitempty" mapstructure:"acceptedOutputModes,omitempty"`
692692

693693
// Blocking indicates if the client will wait for the task to complete. The server may reject
694-
// this if the task is long-running.
695-
Blocking bool `json:"blocking,omitempty" yaml:"blocking,omitempty" mapstructure:"blocking,omitempty"`
694+
// this if the task is long-running. Server might choose to default to true.
695+
Blocking *bool `json:"blocking,omitempty" yaml:"blocking,omitempty" mapstructure:"blocking,omitempty"`
696696

697697
// HistoryLength is the number of most recent messages from the task's history to retrieve in the response.
698698
HistoryLength *int `json:"historyLength,omitempty" yaml:"historyLength,omitempty" mapstructure:"historyLength,omitempty"`

a2aclient/client.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"sync/atomic"
2121

2222
"github.com/a2aproject/a2a-go/a2a"
23+
"github.com/a2aproject/a2a-go/internal/utils"
2324
)
2425

2526
// Config exposes options for customizing Client behavior.
@@ -38,6 +39,10 @@ type Config struct {
3839
// If there's no overlap in supported Transport Factory will return an error on Client
3940
// creation attempt.
4041
PreferredTransports []a2a.TransportProtocol
42+
// Whether client prefers to poll for task updates instead of blocking until a terminal state is reached.
43+
// If set to true, non-streaming send message result might be a Message or a Task in any (including non-terminal) state.
44+
// Callers are responsible for running the polling loop. This configuration does not apply to streaming requests.
45+
Polling bool
4146
}
4247

4348
// Client represents a transport-agnostic implementation of A2A client.
@@ -94,7 +99,7 @@ func (c *Client) CancelTask(ctx context.Context, id *a2a.TaskIDParams) (*a2a.Tas
9499
func (c *Client) SendMessage(ctx context.Context, message *a2a.MessageSendParams) (a2a.SendMessageResult, error) {
95100
method := "SendMessage"
96101

97-
message = c.withDefaultSendConfig(message)
102+
message = c.withDefaultSendConfig(message, blocking(!c.config.Polling))
98103

99104
ctx, err := c.interceptBefore(ctx, method, message)
100105
if err != nil {
@@ -113,7 +118,7 @@ func (c *Client) SendStreamingMessage(ctx context.Context, message *a2a.MessageS
113118
return func(yield func(a2a.Event, error) bool) {
114119
method := "SendStreamingMessage"
115120

116-
message = c.withDefaultSendConfig(message)
121+
message = c.withDefaultSendConfig(message, blocking(true))
117122

118123
ctx, err := c.interceptBefore(ctx, method, message)
119124
if err != nil {
@@ -269,8 +274,10 @@ func (c *Client) Destroy() error {
269274
return c.transport.Destroy()
270275
}
271276

272-
func (c *Client) withDefaultSendConfig(message *a2a.MessageSendParams) *a2a.MessageSendParams {
273-
if c.config.PushConfig == nil && c.config.AcceptedOutputModes == nil {
277+
type blocking bool
278+
279+
func (c *Client) withDefaultSendConfig(message *a2a.MessageSendParams, blocking blocking) *a2a.MessageSendParams {
280+
if c.config.PushConfig == nil && c.config.AcceptedOutputModes == nil && blocking {
274281
return message
275282
}
276283
result := *message
@@ -286,6 +293,7 @@ func (c *Client) withDefaultSendConfig(message *a2a.MessageSendParams) *a2a.Mess
286293
if result.Config.AcceptedOutputModes == nil {
287294
result.Config.AcceptedOutputModes = c.config.AcceptedOutputModes
288295
}
296+
result.Config.Blocking = utils.Ptr(bool(blocking))
289297
return &result
290298
}
291299

a2aclient/client_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323

2424
"github.com/a2aproject/a2a-go/a2a"
25+
"github.com/a2aproject/a2a-go/internal/utils"
2526
"github.com/google/go-cmp/cmp"
2627
)
2728

@@ -175,7 +176,7 @@ func TestClient_DefaultSendMessageConfig(t *testing.T) {
175176
}
176177
interceptor := &testInterceptor{}
177178
client := &Client{
178-
config: Config{PushConfig: pushConfig, AcceptedOutputModes: acceptedModes},
179+
config: Config{PushConfig: pushConfig, AcceptedOutputModes: acceptedModes, Polling: true},
179180
transport: transport,
180181
interceptors: []CallInterceptor{interceptor},
181182
}
@@ -187,7 +188,7 @@ func TestClient_DefaultSendMessageConfig(t *testing.T) {
187188
t.Fatalf("client.SendMessage() error = %v", err)
188189
}
189190
want := &a2a.MessageSendParams{
190-
Config: &a2a.MessageSendConfig{AcceptedOutputModes: acceptedModes, PushConfig: pushConfig},
191+
Config: &a2a.MessageSendConfig{AcceptedOutputModes: acceptedModes, PushConfig: pushConfig, Blocking: utils.Ptr(false)},
191192
}
192193
if diff := cmp.Diff(want, interceptor.lastReq.Payload); diff != "" {
193194
t.Fatalf("client.SendMessage() wrong result (+got,-want) diff = %s", diff)
@@ -214,7 +215,7 @@ func TestClient_DefaultSendStreamingMessageConfig(t *testing.T) {
214215
}
215216
interceptor := &testInterceptor{}
216217
client := &Client{
217-
config: Config{PushConfig: pushConfig, AcceptedOutputModes: acceptedModes},
218+
config: Config{PushConfig: pushConfig, AcceptedOutputModes: acceptedModes, Polling: true},
218219
transport: transport,
219220
interceptors: []CallInterceptor{interceptor},
220221
}
@@ -225,7 +226,7 @@ func TestClient_DefaultSendStreamingMessageConfig(t *testing.T) {
225226
}
226227
}
227228
want := &a2a.MessageSendParams{
228-
Config: &a2a.MessageSendConfig{AcceptedOutputModes: acceptedModes, PushConfig: pushConfig},
229+
Config: &a2a.MessageSendConfig{AcceptedOutputModes: acceptedModes, PushConfig: pushConfig, Blocking: utils.Ptr(true)},
229230
}
230231
if diff := cmp.Diff(want, interceptor.lastReq.Payload); diff != "" {
231232
t.Fatalf("client.SendStreamingMessage() wrong result (+got,-want) diff = %s", diff)

a2apb/pbconv/from_proto.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/a2aproject/a2a-go/a2a"
2121
"github.com/a2aproject/a2a-go/a2apb"
22+
"google.golang.org/protobuf/proto"
2223
structpb "google.golang.org/protobuf/types/known/structpb"
2324
)
2425

@@ -162,7 +163,7 @@ func fromProtoSendMessageConfig(conf *a2apb.SendMessageConfiguration) (*a2a.Mess
162163

163164
result := &a2a.MessageSendConfig{
164165
AcceptedOutputModes: conf.GetAcceptedOutputModes(),
165-
Blocking: conf.GetBlocking(),
166+
Blocking: proto.Bool(conf.GetBlocking()),
166167
PushConfig: pConf,
167168
}
168169

a2apb/pbconv/from_proto_test.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020

2121
"github.com/a2aproject/a2a-go/a2a"
2222
"github.com/a2aproject/a2a-go/a2apb"
23+
"github.com/google/go-cmp/cmp"
24+
"google.golang.org/protobuf/proto"
2325
"google.golang.org/protobuf/types/known/structpb"
2426
)
2527

@@ -187,7 +189,7 @@ func TestFromProto_fromProtoSendMessageConfig(t *testing.T) {
187189
},
188190
want: &a2a.MessageSendConfig{
189191
AcceptedOutputModes: []string{"text/plain"},
190-
Blocking: true,
192+
Blocking: proto.Bool(true),
191193
HistoryLength: &a2aHistoryLen,
192194
PushConfig: &a2a.PushConfig{
193195
ID: "test-push-config",
@@ -205,14 +207,14 @@ func TestFromProto_fromProtoSendMessageConfig(t *testing.T) {
205207
in: &a2apb.SendMessageConfiguration{
206208
HistoryLength: 0,
207209
},
208-
want: &a2a.MessageSendConfig{},
210+
want: &a2a.MessageSendConfig{Blocking: proto.Bool(false)},
209211
},
210212
{
211213
name: "config with no push notification",
212214
in: &a2apb.SendMessageConfiguration{
213215
PushNotification: nil,
214216
},
215-
want: &a2a.MessageSendConfig{},
217+
want: &a2a.MessageSendConfig{Blocking: proto.Bool(false)},
216218
},
217219
{
218220
name: "nil config",
@@ -221,13 +223,15 @@ func TestFromProto_fromProtoSendMessageConfig(t *testing.T) {
221223
},
222224
}
223225
for _, tt := range tests {
224-
got, err := fromProtoSendMessageConfig(tt.in)
225-
if (err != nil) != tt.wantErr {
226-
t.Fatalf("fromProtoSendMessageConfig() error = %v", err)
227-
}
228-
if !reflect.DeepEqual(got, tt.want) {
229-
t.Errorf("fromProtoSendMessageConfig() got = %v, want %v", got, tt.want)
230-
}
226+
t.Run(tt.name, func(t *testing.T) {
227+
got, err := fromProtoSendMessageConfig(tt.in)
228+
if (err != nil) != tt.wantErr {
229+
t.Fatalf("fromProtoSendMessageConfig() error = %v", err)
230+
}
231+
if diff := cmp.Diff(tt.want, got); diff != "" {
232+
t.Errorf("fromProtoSendMessageConfig() wrong result (+got,-want)\ngot = %v\n want %v\ndiff = %s", got, tt.want, diff)
233+
}
234+
})
231235
}
232236
}
233237

@@ -259,7 +263,7 @@ func TestFromProto_fromProtoSendMessageRequest(t *testing.T) {
259263
},
260264
}
261265
a2aConf := &a2a.MessageSendConfig{
262-
Blocking: true,
266+
Blocking: proto.Bool(true),
263267
HistoryLength: &a2aHistoryLen,
264268
PushConfig: &a2a.PushConfig{
265269
ID: "push-config",
@@ -324,7 +328,7 @@ func TestFromProto_fromProtoSendMessageRequest(t *testing.T) {
324328
},
325329
want: &a2a.MessageSendParams{
326330
Message: &a2aMsg,
327-
Config: &a2a.MessageSendConfig{PushConfig: &a2a.PushConfig{}},
331+
Config: &a2a.MessageSendConfig{PushConfig: &a2a.PushConfig{}, Blocking: proto.Bool(false)},
328332
},
329333
},
330334
}

a2apb/pbconv/to_proto.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,11 @@ func toProtoSendMessageConfig(config *a2a.MessageSendConfig) (*a2apb.SendMessage
9393

9494
pConf := &a2apb.SendMessageConfiguration{
9595
AcceptedOutputModes: config.AcceptedOutputModes,
96-
Blocking: config.Blocking,
9796
PushNotification: pushConf,
9897
}
98+
if config.Blocking != nil {
99+
pConf.Blocking = *config.Blocking
100+
}
99101
if config.HistoryLength != nil {
100102
pConf.HistoryLength = int32(*config.HistoryLength)
101103
}

a2asrv/handler.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,11 @@ func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.M
214214
if err != nil {
215215
return nil, err
216216
}
217-
if taskID, required := isAuthRequired(event); required {
217+
218+
if taskID, interrupt := shouldInterruptNonStreaming(params, event); interrupt {
218219
task, err := h.taskStore.Get(ctx, taskID)
219220
if err != nil {
220-
return nil, fmt.Errorf("failed to load task in auth-required state: %w", err)
221+
return nil, fmt.Errorf("failed to load task on event processing interrupt: %w", err)
221222
}
222223
return task, nil
223224
}
@@ -349,12 +350,23 @@ func (h *defaultRequestHandler) OnGetExtendedAgentCard(ctx context.Context) (*a2
349350
return h.authenticatedCardProducer.Card(ctx)
350351
}
351352

352-
func isAuthRequired(event a2a.Event) (a2a.TaskID, bool) {
353+
func shouldInterruptNonStreaming(params *a2a.MessageSendParams, event a2a.Event) (a2a.TaskID, bool) {
354+
// Non-blocking clients receive a result on the first task event, default Blocking to TRUE
355+
if params.Config != nil && params.Config.Blocking != nil && !(*params.Config.Blocking) {
356+
if _, ok := event.(*a2a.Message); ok {
357+
return "", false
358+
}
359+
taskInfo := event.TaskInfo()
360+
return taskInfo.TaskID, true
361+
}
362+
363+
// Non-streaming clients need to be notified when auth is required
353364
switch v := event.(type) {
354365
case *a2a.Task:
355366
return v.ID, v.Status.State == a2a.TaskStateAuthRequired
356367
case *a2a.TaskStatusUpdateEvent:
357368
return v.TaskID, v.Status.State == a2a.TaskStateAuthRequired
358369
}
370+
359371
return "", false
360372
}

0 commit comments

Comments
 (0)