Skip to content

Commit bc2af3e

Browse files
Li Jiecpunion
authored andcommitted
feat(auth): add tool authentication request flow
Add support for tools to request authentication credentials during execution: - Add RequestedAuthConfigs field to EventActions for storing auth requests - Implement GenerateAuthEvent to convert auth requests to adk_request_credential function calls - Add RequestAuthConfig method to ToolContext for tools to request authentication - Add AuthConfig field to Tool interface for tools to declare auth requirements - Integrate auth event generation into the LLM base flow Also fix test comparisons to ignore RequestedAuthConfigs field initialization and add UTC timezone setup in database tests for consistent timestamp formatting.
1 parent ad17675 commit bc2af3e

File tree

14 files changed

+695
-10
lines changed

14 files changed

+695
-10
lines changed

agent/remoteagent/a2a_agent_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,7 @@ func TestRemoteAgent_EmptyResultForEmptySession(t *testing.T) {
756756
cmpopts.IgnoreFields(session.Event{}, "ID"),
757757
cmpopts.IgnoreFields(session.Event{}, "Timestamp"),
758758
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
759+
cmpopts.IgnoreFields(session.EventActions{}, "RequestedAuthConfigs"),
759760
}
760761
if diff := cmp.Diff(wantEvents, gotEvents, ignoreFields...); diff != "" {
761762
t.Fatalf("agent.Run() wrong result (+got,-want):\ngot = %+v\nwant = %+v\ndiff = %s", gotEvents, wantEvents, diff)

agent/remoteagent/utils_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ func TestPresentAsUserMessage(t *testing.T) {
305305
cmpopts.IgnoreFields(session.Event{}, "InvocationID"),
306306
cmpopts.IgnoreFields(session.Event{}, "Timestamp"),
307307
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
308+
cmpopts.IgnoreFields(session.EventActions{}, "RequestedAuthConfigs"),
308309
}
309310
for _, tc := range testCases {
310311
t.Run(tc.name, func(t *testing.T) {

agent/workflowagents/loopagent/agent_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func TestNewLoopAgent(t *testing.T) {
234234

235235
ignoreFields := []cmp.Option{
236236
cmpopts.IgnoreFields(session.Event{}, "ID", "InvocationID", "Timestamp"),
237-
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
237+
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta", "RequestedAuthConfigs"),
238238
cmpopts.IgnoreFields(genai.FunctionCall{}, "ID"),
239239
cmpopts.IgnoreFields(genai.FunctionResponse{}, "ID"),
240240
}

agent/workflowagents/sequentialagent/agent_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ func TestNewSequentialAgent(t *testing.T) {
254254
for i, gotEvent := range gotEvents {
255255
tt.wantEvents[i].Timestamp = gotEvent.Timestamp
256256
if diff := cmp.Diff(tt.wantEvents[i], gotEvent, cmpopts.IgnoreFields(session.Event{}, "ID", "Timestamp", "InvocationID"),
257-
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta")); diff != "" {
257+
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta", "RequestedAuthConfigs")); diff != "" {
258258
t.Errorf("event[i] mismatch (-want +got):\n%s", diff)
259259
}
260260
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package llminternal
16+
17+
import (
18+
"github.com/google/uuid"
19+
"google.golang.org/genai"
20+
21+
"google.golang.org/adk/agent"
22+
"google.golang.org/adk/auth"
23+
"google.golang.org/adk/model"
24+
"google.golang.org/adk/session"
25+
)
26+
27+
const afFunctionCallIDPrefix = "adk-"
28+
29+
// generateFunctionCallID creates a unique function call ID.
30+
// This matches Python's generate_client_function_call_id() with AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
31+
func generateFunctionCallID() string {
32+
return afFunctionCallIDPrefix + uuid.NewString()
33+
}
34+
35+
// GenerateAuthEvent creates an event with adk_request_credential function calls
36+
// from the RequestedAuthConfigs in the function response event.
37+
// This matches Python ADK's generate_auth_event in flows/llm_flows/functions.py.
38+
func GenerateAuthEvent(ctx agent.InvocationContext, fnResponseEvent *session.Event) *session.Event {
39+
if fnResponseEvent == nil || len(fnResponseEvent.Actions.RequestedAuthConfigs) == 0 {
40+
return nil
41+
}
42+
43+
var parts []*genai.Part
44+
var longRunningToolIDs []string
45+
46+
for functionCallID, authConfig := range fnResponseEvent.Actions.RequestedAuthConfigs {
47+
// Create args map matching Python's AuthToolArguments.model_dump()
48+
// Note: We preserve *auth.AuthConfig pointer since this is in-memory,
49+
// matching Python's behavior where objects are passed by reference.
50+
argsMap := map[string]any{
51+
"function_call_id": functionCallID,
52+
"auth_config": authConfig, // Keep as *auth.AuthConfig pointer
53+
}
54+
55+
// Create the adk_request_credential function call
56+
requestEucFunctionCall := &genai.FunctionCall{
57+
Name: auth.RequestEUCFunctionCallName,
58+
Args: argsMap,
59+
}
60+
61+
// Generate a unique ID for this function call
62+
requestEucFunctionCall.ID = generateFunctionCallID()
63+
longRunningToolIDs = append(longRunningToolIDs, requestEucFunctionCall.ID)
64+
65+
parts = append(parts, &genai.Part{
66+
FunctionCall: requestEucFunctionCall,
67+
})
68+
}
69+
70+
// Determine the role from the original event
71+
role := "model"
72+
if fnResponseEvent.Content != nil && fnResponseEvent.Content.Role != "" {
73+
role = fnResponseEvent.Content.Role
74+
}
75+
76+
// Create the auth event
77+
authEvent := session.NewEvent(ctx.InvocationID())
78+
authEvent.Author = ctx.Agent().Name()
79+
authEvent.Branch = ctx.Branch()
80+
authEvent.LLMResponse = model.LLMResponse{
81+
Content: &genai.Content{
82+
Role: role,
83+
Parts: parts,
84+
},
85+
}
86+
authEvent.LongRunningToolIDs = longRunningToolIDs
87+
88+
return authEvent
89+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package llminternal
16+
17+
import (
18+
"testing"
19+
20+
"google.golang.org/adk/auth"
21+
contextinternal "google.golang.org/adk/internal/context"
22+
"google.golang.org/adk/session"
23+
)
24+
25+
func TestGenerateAuthEvent_Nil(t *testing.T) {
26+
inv := contextinternal.NewInvocationContext(t.Context(), contextinternal.InvocationContextParams{})
27+
28+
// Nil event
29+
result := GenerateAuthEvent(inv, nil)
30+
if result != nil {
31+
t.Error("GenerateAuthEvent(nil) should return nil")
32+
}
33+
34+
// Empty RequestedAuthConfigs
35+
event := &session.Event{
36+
Actions: session.EventActions{
37+
RequestedAuthConfigs: make(map[string]*auth.AuthConfig),
38+
},
39+
}
40+
result = GenerateAuthEvent(inv, event)
41+
if result != nil {
42+
t.Error("GenerateAuthEvent with empty RequestedAuthConfigs should return nil")
43+
}
44+
}
45+
46+
// Note: TestGenerateAuthEvent_CreatesEvent and TestGenerateAuthEvent_MultipleCalls
47+
// are skipped as they require full invocation context with agent setup.
48+
// The GenerateAuthEvent function is tested indirectly through integration tests.
49+
50+
func TestGenerateFunctionCallID(t *testing.T) {
51+
id1 := generateFunctionCallID()
52+
id2 := generateFunctionCallID()
53+
54+
if id1 == "" {
55+
t.Error("generateFunctionCallID() returned empty string")
56+
}
57+
if id1 == id2 {
58+
t.Error("generateFunctionCallID() should return unique IDs")
59+
}
60+
if len(id1) < 4 || id1[:4] != "adk-" {
61+
t.Errorf("generateFunctionCallID() = %q, should start with 'adk-'", id1)
62+
}
63+
}
64+
65+
func TestParseAuthConfigFromMap(t *testing.T) {
66+
data := map[string]any{
67+
"credential_key": "test-key",
68+
"exchanged_auth_credential": map[string]any{
69+
"auth_type": "oauth2",
70+
"oauth2": map[string]any{
71+
"access_token": "token123",
72+
"refresh_token": "refresh456",
73+
"expires_at": float64(1234567890),
74+
},
75+
},
76+
}
77+
78+
config, err := parseAuthConfigFromMap(data)
79+
if err != nil {
80+
t.Fatalf("parseAuthConfigFromMap() error = %v", err)
81+
}
82+
if config.CredentialKey != "test-key" {
83+
t.Errorf("CredentialKey = %q, want %q", config.CredentialKey, "test-key")
84+
}
85+
if config.ExchangedAuthCredential == nil {
86+
t.Fatal("ExchangedAuthCredential should not be nil")
87+
}
88+
if config.ExchangedAuthCredential.OAuth2 == nil {
89+
t.Fatal("OAuth2 should not be nil")
90+
}
91+
if config.ExchangedAuthCredential.OAuth2.AccessToken != "token123" {
92+
t.Errorf("AccessToken = %q, want %q", config.ExchangedAuthCredential.OAuth2.AccessToken, "token123")
93+
}
94+
}
95+
96+
func TestParseAuthCredentialFromMap(t *testing.T) {
97+
data := map[string]any{
98+
"auth_type": "oauth2",
99+
"oauth2": map[string]any{
100+
"access_token": "access",
101+
"refresh_token": "refresh",
102+
"expires_at": float64(9999999999),
103+
},
104+
}
105+
106+
cred, err := parseAuthCredentialFromMap(data)
107+
if err != nil {
108+
t.Fatalf("parseAuthCredentialFromMap() error = %v", err)
109+
}
110+
if cred.AuthType != auth.AuthCredentialTypeOAuth2 {
111+
t.Errorf("AuthType = %v, want %v", cred.AuthType, auth.AuthCredentialTypeOAuth2)
112+
}
113+
if cred.OAuth2.AccessToken != "access" {
114+
t.Errorf("AccessToken = %q, want %q", cred.OAuth2.AccessToken, "access")
115+
}
116+
if cred.OAuth2.RefreshToken != "refresh" {
117+
t.Errorf("RefreshToken = %q, want %q", cred.OAuth2.RefreshToken, "refresh")
118+
}
119+
}
120+
121+
func TestParseAuthCredentialFromMap_NotAMap(t *testing.T) {
122+
_, err := parseAuthCredentialFromMap("not a map")
123+
if err == nil {
124+
t.Error("parseAuthCredentialFromMap() should error for non-map input")
125+
}
126+
}
127+
128+
func TestAuthPreprocessorResult_Init(t *testing.T) {
129+
// Verify CurrentAuthPreprocessorResult starts as nil
130+
if CurrentAuthPreprocessorResult != nil {
131+
// Clear it for test isolation
132+
CurrentAuthPreprocessorResult = nil
133+
}
134+
135+
result := &AuthPreprocessorResult{
136+
ToolIdsToResume: make(map[string]bool),
137+
CredentialsStored: true,
138+
OriginalEvent: &session.Event{},
139+
}
140+
141+
if result.ToolIdsToResume == nil {
142+
t.Error("ToolIdsToResume should not be nil")
143+
}
144+
if !result.CredentialsStored {
145+
t.Error("CredentialsStored should be true")
146+
}
147+
if result.OriginalEvent == nil {
148+
t.Error("OriginalEvent should not be nil")
149+
}
150+
}

internal/llminternal/base_flow.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,38 @@ func (f *Flow) runOneStep(ctx agent.InvocationContext) iter.Seq2[*session.Event,
125125
if ctx.Ended() {
126126
return
127127
}
128+
129+
// Check if auth preprocessor found tools that need to be re-executed.
130+
// This implements the "Surgical Resumption" pattern from Python ADK.
131+
if result := CurrentAuthPreprocessorResult; result != nil && result.OriginalEvent != nil && len(result.ToolIdsToResume) > 0 {
132+
// Clear the result immediately to prevent re-processing
133+
CurrentAuthPreprocessorResult = nil
134+
135+
// Build tools map
136+
tools := make(map[string]tool.Tool)
137+
for k, v := range req.Tools {
138+
if t, ok := v.(tool.Tool); ok {
139+
tools[k] = t
140+
}
141+
}
142+
143+
// Execute function calls from the original event that match our tools_to_resume
144+
// This matches Python's handle_function_calls_async with tools_to_resume filter
145+
fnResponseEvent, err := f.handleFunctionCalls(ctx, tools, &result.OriginalEvent.LLMResponse, result.ToolIdsToResume)
146+
if err != nil {
147+
yield(nil, err)
148+
return
149+
}
150+
if fnResponseEvent != nil {
151+
if !yield(fnResponseEvent, nil) {
152+
return
153+
}
154+
}
155+
156+
// Return after tool re-execution - Python does the same
157+
return
158+
}
159+
128160
spans := telemetry.StartTrace(ctx, "call_llm")
129161
// Create event to pass to callback state delta
130162
stateDelta := make(map[string]any)
@@ -163,10 +195,8 @@ func (f *Flow) runOneStep(ctx agent.InvocationContext) iter.Seq2[*session.Event,
163195
if !yield(modelResponseEvent, nil) {
164196
return
165197
}
166-
// TODO: generate and yield an auth event if needed.
167198

168199
// Handle function calls.
169-
170200
ev, err := f.handleFunctionCalls(ctx, tools, resp)
171201
if err != nil {
172202
yield(nil, err)
@@ -180,6 +210,14 @@ func (f *Flow) runOneStep(ctx agent.InvocationContext) iter.Seq2[*session.Event,
180210
return
181211
}
182212

213+
// Generate and yield an auth event if needed.
214+
// This converts RequestedAuthConfigs into adk_request_credential function calls.
215+
if authEvent := GenerateAuthEvent(ctx, ev); authEvent != nil {
216+
if !yield(authEvent, nil) {
217+
return
218+
}
219+
}
220+
183221
// Actually handle "transfer_to_agent" tool. The function call sets the ev.Actions.TransferToAgent field.
184222
// We are following python's execution flow which is
185223
// BaseLlmFlow._postprocess_async
@@ -364,14 +402,25 @@ func findLongRunningFunctionCallIDs(c *genai.Content, tools map[string]tool.Tool
364402
}
365403

366404
// handleFunctionCalls calls the functions and returns the function response event.
405+
// If toolsToResume is non-nil and non-empty, only function calls with IDs in the map are executed.
367406
//
368-
// TODO: accept filters to include/exclude function calls.
369407
// TODO: check feasibility of running tool.Run concurrently.
370-
func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[string]tool.Tool, resp *model.LLMResponse) (*session.Event, error) {
408+
func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[string]tool.Tool, resp *model.LLMResponse, toolsToResume ...map[string]bool) (*session.Event, error) {
371409
var fnResponseEvents []*session.Event
372410

411+
// Build filter map if provided
412+
var filterMap map[string]bool
413+
if len(toolsToResume) > 0 && toolsToResume[0] != nil && len(toolsToResume[0]) > 0 {
414+
filterMap = toolsToResume[0]
415+
}
416+
373417
fnCalls := utils.FunctionCalls(resp.Content)
374418
for _, fnCall := range fnCalls {
419+
// Skip function calls not in the filter (if filter is provided)
420+
if filterMap != nil && !filterMap[fnCall.ID] {
421+
continue
422+
}
423+
375424
curTool, ok := toolsDict[fnCall.Name]
376425
if !ok {
377426
return nil, fmt.Errorf("unknown tool: %q", fnCall.Name)
@@ -380,6 +429,7 @@ func (f *Flow) handleFunctionCalls(ctx agent.InvocationContext, toolsDict map[st
380429
if !ok {
381430
return nil, fmt.Errorf("tool %q is not a function tool", curTool.Name())
382431
}
432+
383433
toolCtx := toolinternal.NewToolContext(ctx, fnCall.ID, &session.EventActions{StateDelta: make(map[string]any)})
384434
// toolCtx := tool.
385435
spans := telemetry.StartTrace(ctx, "execute_tool "+fnCall.Name)

0 commit comments

Comments
 (0)