Skip to content

Commit 708a68b

Browse files
committed
add did wba http authentication header function
1 parent 3cba041 commit 708a68b

File tree

5 files changed

+363
-175
lines changed

5 files changed

+363
-175
lines changed

agent_connect/authentication/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
verify_auth_header_signature, \
99
extract_auth_header_parts
1010

11+
from .did_wba_auth_header import DIDWbaAuthHeader
12+
1113
# Define what should be exported when using "from agent_connect.authentication import *"
1214
__all__ = ['DIDAllClient', \
1315
'create_did_wba_document', \
1416
'resolve_did_wba_document', \
1517
'resolve_did_wba_document_sync', \
1618
'generate_auth_header', \
1719
'verify_auth_header_signature', \
18-
'extract_auth_header_parts']
20+
'extract_auth_header_parts', \
21+
'DIDWbaAuthHeader']
1922

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# AgentConnect: https://github.com/agent-network-protocol/AgentConnect
2+
# Author: GaoWei Chang
3+
# Email: chgaowei@gmail.com
4+
# Website: https://agent-network-protocol.com/
5+
#
6+
# This project is open-sourced under the MIT License. For details, please see the LICENSE file.
7+
8+
import os
9+
import json
10+
import logging
11+
import asyncio
12+
import aiohttp
13+
from pathlib import Path
14+
from typing import Dict, Optional
15+
from cryptography.hazmat.primitives.asymmetric import ec
16+
from cryptography.hazmat.primitives import serialization
17+
from cryptography.hazmat.primitives import hashes
18+
from urllib.parse import urlparse
19+
20+
# Import agent_connect for DID authentication
21+
from .did_wba import (
22+
generate_auth_header
23+
)
24+
25+
class DIDWbaAuthHeader:
26+
"""
27+
Simplified DID authentication client providing HTTP authentication headers.
28+
"""
29+
30+
def __init__(self, did_document_path: str, private_key_path: str):
31+
"""
32+
Initialize the DID authentication client.
33+
34+
Args:
35+
did_document_path: Path to the DID document (absolute or relative path)
36+
private_key_path: Path to the private key (absolute or relative path)
37+
"""
38+
self.did_document_path = did_document_path
39+
self.private_key_path = private_key_path
40+
41+
# State variables
42+
self.did_document = None
43+
self.auth_headers = {} # Store DID authentication headers by domain
44+
self.tokens = {} # Store tokens by domain
45+
46+
logging.info("DIDWbaAuthHeader initialized")
47+
48+
def _get_domain(self, server_url: str) -> str:
49+
"""Extract domain from URL"""
50+
parsed_url = urlparse(server_url)
51+
domain = parsed_url.netloc.split(':')[0]
52+
return domain
53+
54+
def _load_did_document(self) -> Dict:
55+
"""Load DID document"""
56+
try:
57+
if self.did_document:
58+
return self.did_document
59+
60+
# Use the provided path directly, without resolving absolute path
61+
did_path = self.did_document_path
62+
63+
with open(did_path, 'r') as f:
64+
did_document = json.load(f)
65+
66+
self.did_document = did_document
67+
logging.info(f"Loaded DID document: {did_path}")
68+
return did_document
69+
except Exception as e:
70+
logging.error(f"Error loading DID document: {e}")
71+
raise
72+
73+
def _load_private_key(self) -> ec.EllipticCurvePrivateKey:
74+
"""Load private key"""
75+
try:
76+
# Use the provided path directly, without resolving absolute path
77+
key_path = self.private_key_path
78+
79+
with open(key_path, 'rb') as f:
80+
private_key_data = f.read()
81+
82+
private_key = serialization.load_pem_private_key(
83+
private_key_data,
84+
password=None
85+
)
86+
87+
logging.debug(f"Loaded private key: {key_path}")
88+
return private_key
89+
except Exception as e:
90+
logging.error(f"Error loading private key: {e}")
91+
raise
92+
93+
def _sign_callback(self, content: bytes, method_fragment: str) -> bytes:
94+
"""Sign callback function"""
95+
try:
96+
private_key = self._load_private_key()
97+
signature = private_key.sign(
98+
content,
99+
ec.ECDSA(hashes.SHA256())
100+
)
101+
102+
logging.debug(f"Signed content with method fragment: {method_fragment}")
103+
return signature
104+
except Exception as e:
105+
logging.error(f"Error signing content: {e}")
106+
raise
107+
108+
def _generate_auth_header(self, domain: str) -> str:
109+
"""Generate DID authentication header"""
110+
try:
111+
did_document = self._load_did_document()
112+
113+
auth_header = generate_auth_header(
114+
did_document,
115+
domain,
116+
self._sign_callback
117+
)
118+
119+
logging.info(f"Generated authentication header for domain {domain}: {auth_header[:30]}...")
120+
return auth_header
121+
except Exception as e:
122+
logging.error(f"Error generating authentication header: {e}")
123+
raise
124+
125+
def get_auth_header(self, server_url: str, force_new: bool = False) -> Dict[str, str]:
126+
"""
127+
Get authentication header.
128+
129+
Args:
130+
server_url: Server URL
131+
force_new: Whether to force generate a new DID authentication header
132+
133+
Returns:
134+
Dict[str, str]: HTTP header dictionary
135+
"""
136+
domain = self._get_domain(server_url)
137+
138+
# If there is a token and not forcing a new authentication header, return the token
139+
if domain in self.tokens and not force_new:
140+
token = self.tokens[domain]
141+
logging.info(f"Using existing token for domain {domain}")
142+
return {"Authorization": f"Bearer {token}"}
143+
144+
# Otherwise, generate or use existing DID authentication header
145+
if domain not in self.auth_headers or force_new:
146+
self.auth_headers[domain] = self._generate_auth_header(domain)
147+
148+
logging.info(f"Using DID authentication header for domain {domain}")
149+
return {"Authorization": self.auth_headers[domain]}
150+
151+
def update_token(self, server_url: str, headers: Dict[str, str]) -> Optional[str]:
152+
"""
153+
Update token from response headers.
154+
155+
Args:
156+
server_url: Server URL
157+
headers: Response header dictionary
158+
159+
Returns:
160+
Optional[str]: Updated token, or None if no valid token is found
161+
"""
162+
domain = self._get_domain(server_url)
163+
auth_header = headers.get("Authorization")
164+
165+
if auth_header and auth_header.lower().startswith("bearer "):
166+
token = auth_header[7:] # Remove "Bearer " prefix
167+
self.tokens[domain] = token
168+
logging.info(f"Updated token for domain {domain}: {token[:30]}...")
169+
return token
170+
else:
171+
logging.debug(f"No valid token found in response headers for domain {domain}")
172+
return None
173+
174+
def clear_token(self, server_url: str) -> None:
175+
"""
176+
Clear token for the specified domain.
177+
178+
Args:
179+
server_url: Server URL
180+
"""
181+
domain = self._get_domain(server_url)
182+
if domain in self.tokens:
183+
del self.tokens[domain]
184+
logging.info(f"Cleared token for domain {domain}")
185+
else:
186+
logging.debug(f"No stored token for domain {domain}")
187+
188+
def clear_all_tokens(self) -> None:
189+
"""Clear all tokens for all domains"""
190+
self.tokens.clear()
191+
logging.info("Cleared all tokens for all domains")
192+
193+
# # Example usage
194+
# async def example_usage():
195+
# # Get current script directory
196+
# current_dir = Path(__file__).parent
197+
# # Get project root directory (parent of current directory)
198+
# base_dir = current_dir.parent
199+
200+
# # Create client with absolute paths
201+
# client = DIDWbaAuthHeader(
202+
# did_document_path=str(base_dir / "use_did_test_public/did.json"),
203+
# private_key_path=str(base_dir / "use_did_test_public/key-1_private.pem")
204+
# )
205+
206+
# server_url = "http://localhost:9870"
207+
208+
# # Get authentication header (first call, returns DID authentication header)
209+
# headers = client.get_auth_header(server_url)
210+
211+
# # Send request
212+
# async with aiohttp.ClientSession() as session:
213+
# async with session.get(
214+
# f"{server_url}/agents/travel/hotel/ad/ph/12345/ad.json",
215+
# headers=headers
216+
# ) as response:
217+
# # Check response
218+
# print(f"Status code: {response.status}")
219+
220+
# # If authentication is successful, update token
221+
# if response.status == 200:
222+
# token = client.update_token(server_url, dict(response.headers))
223+
# if token:
224+
# print(f"Received token: {token[:30]}...")
225+
# else:
226+
# print("No token received in response headers")
227+
228+
# # If authentication fails and a token was used, clear the token and retry
229+
# elif response.status == 401:
230+
# print("Invalid token, clearing and using DID authentication")
231+
# client.clear_token(server_url)
232+
# # Retry request here
233+
234+
# # Get authentication header again (if a token was obtained in the previous step, this will return a token authentication header)
235+
# headers = client.get_auth_header(server_url)
236+
# print(f"Header for second request: {headers}")
237+
238+
# # Force use of DID authentication header
239+
# headers = client.get_auth_header(server_url, force_new=True)
240+
# print(f"Forced use of DID authentication header: {headers}")
241+
242+
# # Test different domain
243+
# another_server_url = "http://api.example.com"
244+
# headers = client.get_auth_header(another_server_url)
245+
# print(f"Header for another domain: {headers}")
246+
247+
# if __name__ == "__main__":
248+
# asyncio.run(example_usage())

examples/did_wba_examples/basic.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from cryptography.hazmat.primitives.asymmetric import ec
2323
from canonicaljson import encode_canonical_json
2424

25-
from agent_connect.authentication.did_wba import (
25+
from agent_connect.authentication import (
2626
create_did_wba_document,
2727
resolve_did_wba_document,
28-
generate_auth_header
28+
DIDWbaAuthHeader
2929
)
3030
from agent_connect.utils.log_base import set_log_color_level
3131

@@ -76,21 +76,23 @@ async def download_did_document(url: str) -> dict:
7676
logging.error("Failed to download DID document: %s", e)
7777
return None
7878

79-
async def test_did_auth(url: str, auth_header: str) -> tuple[bool, str]:
79+
async def test_did_auth(url: str, auth_client: DIDWbaAuthHeader) -> tuple[bool, str]:
8080
"""Test DID authentication and get token"""
8181
try:
8282
local_url = convert_url_for_local_testing(url)
8383
logging.info("Converting URL from %s to %s", url, local_url)
8484

85+
# 获取认证头
86+
auth_headers = auth_client.get_auth_header(local_url)
87+
8588
async with aiohttp.ClientSession() as session:
8689
async with session.get(
8790
local_url,
88-
headers={'Authorization': auth_header}
91+
headers=auth_headers
8992
) as response:
90-
token = response.headers.get('Authorization', '')
91-
if token.startswith('Bearer '):
92-
token = token[7:] # Remove 'Bearer ' prefix
93-
return response.status == 200, token
93+
# 更新令牌
94+
token = auth_client.update_token(local_url, dict(response.headers))
95+
return response.status == 200, token or ''
9496
except Exception as e:
9597
logging.error("DID authentication test failed: %s", e)
9698
return False, ''
@@ -117,31 +119,6 @@ def save_private_key(unique_id: str, keys: dict, did_document: dict) -> str:
117119

118120
return str(user_dir)
119121

120-
def load_private_key(private_key_dir: str, method_fragment: str) -> ec.EllipticCurvePrivateKey:
121-
"""Load private key from file"""
122-
key_dir = Path(private_key_dir)
123-
key_path = key_dir / f"{method_fragment}_private.pem"
124-
125-
logging.info("Loading private key from %s", key_path)
126-
with open(key_path, 'rb') as f:
127-
private_key_bytes = f.read()
128-
return serialization.load_pem_private_key(
129-
private_key_bytes,
130-
password=None
131-
)
132-
133-
def sign_callback(content: bytes, method_fragment: str) -> bytes:
134-
"""Sign content using private key"""
135-
# Load private key using the global variable
136-
private_key = load_private_key(sign_callback.private_key_dir, method_fragment)
137-
138-
# Sign the content
139-
signature = private_key.sign(
140-
content,
141-
ec.ECDSA(hashes.SHA256())
142-
)
143-
return signature
144-
145122
async def main(unique_id: str = None, agent_description_url: str = None):
146123
"""
147124
Main function to demonstrate DID WBA authentication
@@ -167,9 +144,10 @@ async def main(unique_id: str = None, agent_description_url: str = None):
167144
agent_description_url=agent_description_url
168145
)
169146

170-
# 4. Save private keys, DID document and set path for sign_callback
147+
# 4. Save private keys and DID document
171148
user_dir = save_private_key(unique_id, keys, did_document)
172-
sign_callback.private_key_dir = user_dir
149+
did_document_path = str(Path(user_dir) / "did.json")
150+
private_key_path = str(Path(user_dir) / "key-1_private.pem")
173151

174152
# 5. Upload DID document (This should be stored on your server)
175153
document_url = f"https://{server_domain}{did_path}"
@@ -180,24 +158,28 @@ async def main(unique_id: str = None, agent_description_url: str = None):
180158
return
181159
logging.info("DID document uploaded successfully")
182160

183-
# 7. Generate authentication header
184-
logging.info("Generating authentication header...")
185-
auth_header = generate_auth_header(
186-
did_document,
187-
server_domain,
188-
sign_callback
161+
# 6. 创建 DIDWbaAuthHeader 实例
162+
logging.info("Creating DIDWbaAuthHeader instance...")
163+
auth_client = DIDWbaAuthHeader(
164+
did_document_path=did_document_path,
165+
private_key_path=private_key_path
189166
)
190167

191-
# 8. Test DID authentication and get token
168+
# 7. Test DID authentication and get token
192169
test_url = f"https://{server_domain}/wba/test"
193170
logging.info("Testing DID authentication at %s", test_url)
194-
auth_success, token = await test_did_auth(test_url, auth_header)
171+
auth_success, token = await test_did_auth(test_url, auth_client)
195172

196-
if not auth_success or not token:
197-
logging.error(f"DID authentication test failed. auth_success: {auth_success}, token: {token}")
173+
if not auth_success:
174+
logging.error("DID authentication test failed")
198175
return
199176

200177
logging.info("DID authentication test successful")
178+
179+
if token:
180+
logging.info(f"Received token: {token}")
181+
else:
182+
logging.info("No token received from server")
201183

202184
if __name__ == "__main__":
203185
set_log_color_level(logging.INFO)

0 commit comments

Comments
 (0)