Skip to content

Commit 25b9aae

Browse files
authored
feat: client auth interceptor (a2aproject#90)
* `AuthInterceptor` similar to what [python provides](https://github.com/a2aproject/a2a-python/blob/main/src/a2a/client/auth/interceptor.py). * Metadata propagation tests.
1 parent ee0e7ed commit 25b9aae

File tree

5 files changed

+539
-176
lines changed

5 files changed

+539
-176
lines changed

a2aclient/auth.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ package a2aclient
1717
import (
1818
"context"
1919
"errors"
20+
"fmt"
2021
"sync"
2122

2223
"github.com/a2aproject/a2a-go/a2a"
24+
"github.com/a2aproject/a2a-go/log"
2325
)
2426

2527
// ErrCredentialNotFound is returned by CredentialsService if a credential for the provided
@@ -56,6 +58,44 @@ type AuthInterceptor struct {
5658
Service CredentialsService
5759
}
5860

61+
func (ai *AuthInterceptor) Before(ctx context.Context, req *Request) (context.Context, error) {
62+
if req.Card == nil || req.Card.Security == nil || req.Card.SecuritySchemes == nil {
63+
return ctx, nil
64+
}
65+
66+
sessionID, ok := SessionIDFrom(ctx)
67+
if !ok {
68+
return ctx, nil
69+
}
70+
71+
for _, requirement := range req.Card.Security {
72+
for schemeName := range requirement {
73+
credential, err := ai.Service.Get(ctx, sessionID, schemeName)
74+
if errors.Is(err, ErrCredentialNotFound) {
75+
continue
76+
}
77+
if err != nil {
78+
log.Error(ctx, "credentials service error", err)
79+
continue
80+
}
81+
scheme, ok := req.Card.SecuritySchemes[schemeName]
82+
if !ok {
83+
continue
84+
}
85+
switch v := scheme.(type) {
86+
case a2a.HTTPAuthSecurityScheme, a2a.OAuth2SecurityScheme:
87+
req.Meta["Authorization"] = []string{fmt.Sprintf("Bearer %s", credential)}
88+
return ctx, nil
89+
case a2a.APIKeySecurityScheme:
90+
req.Meta[v.Name] = []string{string(credential)}
91+
return ctx, nil
92+
}
93+
}
94+
}
95+
96+
return ctx, nil
97+
}
98+
5999
// CredentialsService is used by auth interceptor for resolving credentials.
60100
type CredentialsService interface {
61101
Get(ctx context.Context, sid SessionID, scheme a2a.SecuritySchemeName) (AuthCredential, error)
@@ -70,9 +110,11 @@ type InMemoryCredentialsStore struct {
70110
credentials map[SessionID]SessionCredentials
71111
}
72112

113+
var _ CredentialsService = (*InMemoryCredentialsStore)(nil)
114+
73115
// NewInMemoryCredentialsStore initializes an InMemoryCredentialsStore.
74-
func NewInMemoryCredentialsStore() InMemoryCredentialsStore {
75-
return InMemoryCredentialsStore{
116+
func NewInMemoryCredentialsStore() *InMemoryCredentialsStore {
117+
return &InMemoryCredentialsStore{
76118
credentials: make(map[SessionID]SessionCredentials),
77119
}
78120
}

a2aclient/auth_test.go

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
// Copyright 2025 The A2A Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package a2aclient
16+
17+
import (
18+
"context"
19+
"errors"
20+
"net"
21+
"testing"
22+
23+
"github.com/a2aproject/a2a-go/a2a"
24+
"github.com/a2aproject/a2a-go/a2agrpc"
25+
"github.com/a2aproject/a2a-go/a2asrv"
26+
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
27+
"github.com/google/go-cmp/cmp"
28+
"google.golang.org/grpc"
29+
"google.golang.org/grpc/credentials/insecure"
30+
"google.golang.org/grpc/test/bufconn"
31+
)
32+
33+
type mockAgentExecutor struct {
34+
ExecuteFn func(context.Context, *a2asrv.RequestContext, eventqueue.Queue) error
35+
}
36+
37+
func (e *mockAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestContext, q eventqueue.Queue) error {
38+
if e.ExecuteFn != nil {
39+
return e.ExecuteFn(ctx, reqCtx, q)
40+
}
41+
return nil
42+
}
43+
44+
func (e *mockAgentExecutor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestContext, q eventqueue.Queue) error {
45+
return nil
46+
}
47+
48+
func startGRPCTestServer(t *testing.T, handler a2asrv.RequestHandler, listener *bufconn.Listener) {
49+
s := grpc.NewServer()
50+
grpcHandler := a2agrpc.NewHandler(handler)
51+
grpcHandler.RegisterWith(s)
52+
if err := s.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
53+
t.Logf("Server exited with error: %v", err)
54+
}
55+
}
56+
57+
func withTestGRPCTransport(listener *bufconn.Listener) FactoryOption {
58+
return WithGRPCTransport(
59+
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
60+
return listener.Dial()
61+
}),
62+
grpc.WithTransportCredentials(insecure.NewCredentials()),
63+
)
64+
}
65+
66+
func TestAuth_GRPC(t *testing.T) {
67+
ctx := t.Context()
68+
listener := bufconn.Listen(1024 * 1024)
69+
70+
var capturedCallContext *a2asrv.CallContext
71+
executor := &mockAgentExecutor{
72+
ExecuteFn: func(ctx context.Context, reqCtx *a2asrv.RequestContext, q eventqueue.Queue) error {
73+
capturedCallContext, _ = a2asrv.CallContextFrom(ctx)
74+
return q.Write(ctx, a2a.NewMessage(a2a.MessageRoleAgent))
75+
},
76+
}
77+
handler := a2asrv.NewHandler(executor)
78+
go startGRPCTestServer(t, handler, listener)
79+
80+
schemeName := a2a.SecuritySchemeName("oauth2")
81+
card := &a2a.AgentCard{
82+
PreferredTransport: a2a.TransportProtocolGRPC,
83+
URL: "passthrough:///bufnet",
84+
Security: []a2a.SecurityRequirements{{schemeName: []string{}}},
85+
SecuritySchemes: a2a.NamedSecuritySchemes{
86+
schemeName: a2a.OAuth2SecurityScheme{},
87+
},
88+
}
89+
90+
credStore := NewInMemoryCredentialsStore()
91+
client, err := NewFromCard(
92+
ctx,
93+
card,
94+
withTestGRPCTransport(listener),
95+
WithInterceptors(&AuthInterceptor{Service: credStore}),
96+
)
97+
if err != nil {
98+
t.Fatalf("a2aclient.NewFromCard() error = %v", err)
99+
}
100+
101+
token := "secret"
102+
sessionID := SessionID("abcd")
103+
credStore.Set(sessionID, schemeName, AuthCredential(token))
104+
105+
ctx = WithSessionID(ctx, sessionID)
106+
_, err = client.SendMessage(ctx, &a2a.MessageSendParams{Message: a2a.NewMessage(a2a.MessageRoleUser)})
107+
if err != nil {
108+
t.Fatalf("client.SendMessage() error = %v", err)
109+
}
110+
111+
auth, _ := capturedCallContext.RequestMeta().Get("authorization")
112+
if diff := cmp.Diff([]string{"Bearer " + token}, auth); diff != "" {
113+
t.Fatalf("RequestMeta[authorization] wrong value = %v, want = %v", auth, []string{"Bearer " + token})
114+
}
115+
}
116+
117+
func TestAuthInterceptor(t *testing.T) {
118+
type storedCred struct {
119+
sid SessionID
120+
scheme a2a.SecuritySchemeName
121+
cred AuthCredential
122+
}
123+
124+
toSchemeName := func(s string) a2a.SecuritySchemeName { return a2a.SecuritySchemeName(s) }
125+
126+
testCases := []struct {
127+
name string
128+
sid SessionID
129+
stored []*storedCred
130+
card *a2a.AgentCard
131+
want CallMeta
132+
}{
133+
{
134+
name: "http auth",
135+
sid: SessionID("123"),
136+
stored: []*storedCred{{
137+
sid: SessionID("123"),
138+
scheme: toSchemeName("test"),
139+
cred: AuthCredential("secret"),
140+
}},
141+
card: &a2a.AgentCard{
142+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
143+
SecuritySchemes: a2a.NamedSecuritySchemes{
144+
toSchemeName("test"): a2a.HTTPAuthSecurityScheme{},
145+
},
146+
},
147+
want: CallMeta{"Authorization": []string{"Bearer secret"}},
148+
},
149+
{
150+
name: "ouath2",
151+
sid: SessionID("123"),
152+
stored: []*storedCred{{
153+
sid: SessionID("123"),
154+
scheme: toSchemeName("test"),
155+
cred: AuthCredential("secret"),
156+
}},
157+
card: &a2a.AgentCard{
158+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
159+
SecuritySchemes: a2a.NamedSecuritySchemes{
160+
toSchemeName("test"): a2a.OAuth2SecurityScheme{},
161+
},
162+
},
163+
want: CallMeta{"Authorization": []string{"Bearer secret"}},
164+
},
165+
{
166+
name: "api key",
167+
sid: SessionID("123"),
168+
stored: []*storedCred{{
169+
sid: SessionID("123"),
170+
scheme: toSchemeName("test"),
171+
cred: AuthCredential("secret"),
172+
}},
173+
card: &a2a.AgentCard{
174+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
175+
SecuritySchemes: a2a.NamedSecuritySchemes{
176+
toSchemeName("test"): a2a.APIKeySecurityScheme{Name: "X-Custom-Auth"},
177+
},
178+
},
179+
want: CallMeta{"X-Custom-Auth": []string{"secret"}},
180+
},
181+
{
182+
name: "first credential chosen",
183+
sid: SessionID("123"),
184+
stored: []*storedCred{
185+
{
186+
sid: SessionID("123"),
187+
scheme: toSchemeName("test-2"),
188+
cred: AuthCredential("secret-2"),
189+
},
190+
{
191+
sid: SessionID("123"),
192+
scheme: toSchemeName("test-3"),
193+
cred: AuthCredential("secret-3"),
194+
},
195+
},
196+
card: &a2a.AgentCard{
197+
Security: []a2a.SecurityRequirements{
198+
{toSchemeName("test"): []string{}},
199+
{toSchemeName("test-2"): []string{}},
200+
{toSchemeName("test-3"): []string{}},
201+
},
202+
SecuritySchemes: a2a.NamedSecuritySchemes{
203+
toSchemeName("test"): a2a.OAuth2SecurityScheme{},
204+
toSchemeName("test-2"): a2a.HTTPAuthSecurityScheme{},
205+
toSchemeName("test-3"): a2a.APIKeySecurityScheme{Name: "X-Custom-Auth"},
206+
},
207+
},
208+
want: CallMeta{"Authorization": []string{"Bearer secret-2"}},
209+
},
210+
{
211+
name: "no session",
212+
card: &a2a.AgentCard{
213+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
214+
SecuritySchemes: a2a.NamedSecuritySchemes{
215+
toSchemeName("test"): a2a.APIKeySecurityScheme{Name: "X-Custom-Auth"},
216+
},
217+
},
218+
want: CallMeta{},
219+
},
220+
{
221+
name: "different session",
222+
sid: SessionID("123"),
223+
stored: []*storedCred{{
224+
sid: SessionID("321"),
225+
scheme: toSchemeName("test"),
226+
cred: AuthCredential("secret"),
227+
}},
228+
card: &a2a.AgentCard{
229+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
230+
SecuritySchemes: a2a.NamedSecuritySchemes{
231+
toSchemeName("test"): a2a.APIKeySecurityScheme{Name: "X-Custom-Auth"},
232+
},
233+
},
234+
want: CallMeta{},
235+
},
236+
{
237+
name: "no card",
238+
sid: SessionID("123"),
239+
stored: []*storedCred{{
240+
sid: SessionID("123"),
241+
scheme: toSchemeName("test"),
242+
cred: AuthCredential("secret"),
243+
}},
244+
want: CallMeta{},
245+
},
246+
{
247+
name: "no matching credential",
248+
sid: SessionID("123"),
249+
stored: []*storedCred{{
250+
sid: SessionID("123"),
251+
scheme: toSchemeName("test-2"),
252+
cred: AuthCredential("secret"),
253+
}},
254+
card: &a2a.AgentCard{
255+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
256+
SecuritySchemes: a2a.NamedSecuritySchemes{
257+
toSchemeName("test"): a2a.OAuth2SecurityScheme{},
258+
},
259+
},
260+
want: CallMeta{},
261+
},
262+
{
263+
name: "no security requirements",
264+
sid: SessionID("123"),
265+
stored: []*storedCred{{
266+
sid: SessionID("123"),
267+
scheme: toSchemeName("test"),
268+
cred: AuthCredential("secret"),
269+
}},
270+
card: &a2a.AgentCard{
271+
SecuritySchemes: a2a.NamedSecuritySchemes{
272+
toSchemeName("test"): a2a.OAuth2SecurityScheme{},
273+
},
274+
},
275+
want: CallMeta{},
276+
},
277+
{
278+
name: "no security schemes",
279+
sid: SessionID("123"),
280+
stored: []*storedCred{{
281+
sid: SessionID("123"),
282+
scheme: toSchemeName("test"),
283+
cred: AuthCredential("secret"),
284+
}},
285+
card: &a2a.AgentCard{
286+
Security: []a2a.SecurityRequirements{{toSchemeName("test"): []string{}}},
287+
},
288+
want: CallMeta{},
289+
},
290+
}
291+
292+
for _, tc := range testCases {
293+
t.Run(tc.name, func(t *testing.T) {
294+
callMeta := CallMeta{}
295+
296+
ctx := t.Context()
297+
if tc.sid != "" {
298+
ctx = WithSessionID(ctx, tc.sid)
299+
}
300+
301+
credStore := NewInMemoryCredentialsStore()
302+
for _, stored := range tc.stored {
303+
credStore.Set(stored.sid, stored.scheme, stored.cred)
304+
}
305+
306+
interceptor := &AuthInterceptor{Service: credStore}
307+
_, err := interceptor.Before(ctx, &Request{Meta: callMeta, Card: tc.card})
308+
if err != nil {
309+
t.Errorf("interceptor.Before() error = %v", err)
310+
}
311+
312+
if diff := cmp.Diff(tc.want, callMeta); diff != "" {
313+
t.Errorf("wrong CallMeta (+got,-want) diff = %s", diff)
314+
}
315+
})
316+
}
317+
}

0 commit comments

Comments
 (0)