Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/remoteagent/a2a_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ func TestRemoteAgent_EmptyResultForEmptySession(t *testing.T) {
cmpopts.IgnoreFields(session.Event{}, "ID"),
cmpopts.IgnoreFields(session.Event{}, "Timestamp"),
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
cmpopts.IgnoreFields(session.EventActions{}, "RequestedAuthConfigs"),
}
if diff := cmp.Diff(wantEvents, gotEvents, ignoreFields...); diff != "" {
t.Fatalf("agent.Run() wrong result (+got,-want):\ngot = %+v\nwant = %+v\ndiff = %s", gotEvents, wantEvents, diff)
Expand Down
1 change: 1 addition & 0 deletions agent/remoteagent/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ func TestPresentAsUserMessage(t *testing.T) {
cmpopts.IgnoreFields(session.Event{}, "InvocationID"),
cmpopts.IgnoreFields(session.Event{}, "Timestamp"),
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
cmpopts.IgnoreFields(session.EventActions{}, "RequestedAuthConfigs"),
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion agent/workflowagents/loopagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func TestNewLoopAgent(t *testing.T) {

ignoreFields := []cmp.Option{
cmpopts.IgnoreFields(session.Event{}, "ID", "InvocationID", "Timestamp"),
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta"),
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta", "RequestedAuthConfigs"),
cmpopts.IgnoreFields(genai.FunctionCall{}, "ID"),
cmpopts.IgnoreFields(genai.FunctionResponse{}, "ID"),
}
Expand Down
2 changes: 1 addition & 1 deletion agent/workflowagents/sequentialagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func TestNewSequentialAgent(t *testing.T) {
for i, gotEvent := range gotEvents {
tt.wantEvents[i].Timestamp = gotEvent.Timestamp
if diff := cmp.Diff(tt.wantEvents[i], gotEvent, cmpopts.IgnoreFields(session.Event{}, "ID", "Timestamp", "InvocationID"),
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta")); diff != "" {
cmpopts.IgnoreFields(session.EventActions{}, "StateDelta", "RequestedAuthConfigs")); diff != "" {
t.Errorf("event[i] mismatch (-want +got):\n%s", diff)
}
}
Expand Down
160 changes: 160 additions & 0 deletions auth/auth_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"sort"
"strconv"
)

// AuthConfig combines auth scheme and credentials for a tool.
// This is passed to tools that require authentication.
type AuthConfig struct {
// AuthScheme defines how the API expects authentication.
AuthScheme AuthScheme `json:"authScheme"`
// RawAuthCredential is the initial credential (e.g., client_id/secret).
RawAuthCredential *AuthCredential `json:"rawAuthCredential,omitempty"`
// ExchangedAuthCredential is the processed credential (e.g., access_token).
ExchangedAuthCredential *AuthCredential `json:"exchangedAuthCredential,omitempty"`
// CredentialKey is a unique key for persisting this credential.
CredentialKey string `json:"credentialKey,omitempty"`
}

// NewAuthConfig creates a new AuthConfig with the given scheme and credential.
// If credentialKey is empty, it will be generated automatically.
func NewAuthConfig(scheme AuthScheme, credential *AuthCredential) (*AuthConfig, error) {
cfg := &AuthConfig{
AuthScheme: scheme,
RawAuthCredential: credential,
}
if cfg.CredentialKey == "" {
key, err := cfg.generateCredentialKey()
if err != nil {
return nil, fmt.Errorf("generate credential key: %w", err)
}
cfg.CredentialKey = key
}
return cfg, nil
}

// generateCredentialKey creates a unique key based on auth scheme and credential.
func (c *AuthConfig) generateCredentialKey() (string, error) {
var schemePart, credPart string
if c.AuthScheme != nil {
schemeJSON, err := stableJSON(c.AuthScheme)
if err != nil {
return "", fmt.Errorf("marshal auth scheme: %w", err)
}
schemeType := c.AuthScheme.GetType()
h := sha256.Sum256([]byte(schemeJSON))
schemePart = fmt.Sprintf("%s_%x", schemeType, h[:8])
}
if c.RawAuthCredential != nil {
credJSON, err := stableJSON(c.RawAuthCredential)
if err != nil {
return "", fmt.Errorf("marshal auth credential: %w", err)
}
h := sha256.Sum256([]byte(credJSON))
credPart = fmt.Sprintf("%s_%x", c.RawAuthCredential.AuthType, h[:8])
}
return fmt.Sprintf("adk_%s_%s", schemePart, credPart), nil
}

// Copy creates a deep copy of the AuthConfig.
func (c *AuthConfig) Copy() *AuthConfig {
if c == nil {
return nil
}
return &AuthConfig{
AuthScheme: c.AuthScheme, // AuthScheme is typically immutable
RawAuthCredential: c.RawAuthCredential.Copy(),
ExchangedAuthCredential: c.ExchangedAuthCredential.Copy(),
CredentialKey: c.CredentialKey,
}
}

// stableJSON returns a deterministic JSON representation with sorted map keys.
func stableJSON(v interface{}) (string, error) {
raw, err := json.Marshal(v)
if err != nil {
return "", err
}
var data interface{}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
if err := dec.Decode(&data); err != nil {
return "", err
}
var buf bytes.Buffer
if err := encodeCanonical(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}

func encodeCanonical(buf *bytes.Buffer, v interface{}) error {
switch val := v.(type) {
case nil:
buf.WriteString("null")
case bool:
if val {
buf.WriteString("true")
} else {
buf.WriteString("false")
}
case string:
buf.WriteString(strconv.Quote(val))
case json.Number:
buf.WriteString(val.String())
case float64:
buf.WriteString(strconv.FormatFloat(val, 'g', -1, 64))
case []interface{}:
buf.WriteByte('[')
for i, elem := range val {
if i > 0 {
buf.WriteByte(',')
}
if err := encodeCanonical(buf, elem); err != nil {
return err
}
}
buf.WriteByte(']')
case map[string]interface{}:
buf.WriteByte('{')
keys := make([]string, 0, len(val))
for k := range val {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(strconv.Quote(k))
buf.WriteByte(':')
if err := encodeCanonical(buf, val[k]); err != nil {
return err
}
}
buf.WriteByte('}')
default:
return fmt.Errorf("unsupported JSON canonicalization type %T", v)
}
return nil
}
Loading