1+ # AgentConnect: https://github.com/agent-network-protocol/AgentConnect
2+ # Author: GaoWei Chang
3+ # Email: chgaowei@gmail.com
4+ # Website: https://agent-network-protocol.com/
5+ #
6+ # This project is open-sourced under the MIT License. For details, please see the LICENSE file.
7+
8+ import os
9+ import json
10+ import logging
11+ import asyncio
12+ import aiohttp
13+ from pathlib import Path
14+ from typing import Dict , Optional
15+ from cryptography .hazmat .primitives .asymmetric import ec
16+ from cryptography .hazmat .primitives import serialization
17+ from cryptography .hazmat .primitives import hashes
18+ from urllib .parse import urlparse
19+
20+ # Import agent_connect for DID authentication
21+ from .did_wba import (
22+ generate_auth_header
23+ )
24+
25+ class DIDWbaAuthHeader :
26+ """
27+ Simplified DID authentication client providing HTTP authentication headers.
28+ """
29+
30+ def __init__ (self , did_document_path : str , private_key_path : str ):
31+ """
32+ Initialize the DID authentication client.
33+
34+ Args:
35+ did_document_path: Path to the DID document (absolute or relative path)
36+ private_key_path: Path to the private key (absolute or relative path)
37+ """
38+ self .did_document_path = did_document_path
39+ self .private_key_path = private_key_path
40+
41+ # State variables
42+ self .did_document = None
43+ self .auth_headers = {} # Store DID authentication headers by domain
44+ self .tokens = {} # Store tokens by domain
45+
46+ logging .info ("DIDWbaAuthHeader initialized" )
47+
48+ def _get_domain (self , server_url : str ) -> str :
49+ """Extract domain from URL"""
50+ parsed_url = urlparse (server_url )
51+ domain = parsed_url .netloc .split (':' )[0 ]
52+ return domain
53+
54+ def _load_did_document (self ) -> Dict :
55+ """Load DID document"""
56+ try :
57+ if self .did_document :
58+ return self .did_document
59+
60+ # Use the provided path directly, without resolving absolute path
61+ did_path = self .did_document_path
62+
63+ with open (did_path , 'r' ) as f :
64+ did_document = json .load (f )
65+
66+ self .did_document = did_document
67+ logging .info (f"Loaded DID document: { did_path } " )
68+ return did_document
69+ except Exception as e :
70+ logging .error (f"Error loading DID document: { e } " )
71+ raise
72+
73+ def _load_private_key (self ) -> ec .EllipticCurvePrivateKey :
74+ """Load private key"""
75+ try :
76+ # Use the provided path directly, without resolving absolute path
77+ key_path = self .private_key_path
78+
79+ with open (key_path , 'rb' ) as f :
80+ private_key_data = f .read ()
81+
82+ private_key = serialization .load_pem_private_key (
83+ private_key_data ,
84+ password = None
85+ )
86+
87+ logging .debug (f"Loaded private key: { key_path } " )
88+ return private_key
89+ except Exception as e :
90+ logging .error (f"Error loading private key: { e } " )
91+ raise
92+
93+ def _sign_callback (self , content : bytes , method_fragment : str ) -> bytes :
94+ """Sign callback function"""
95+ try :
96+ private_key = self ._load_private_key ()
97+ signature = private_key .sign (
98+ content ,
99+ ec .ECDSA (hashes .SHA256 ())
100+ )
101+
102+ logging .debug (f"Signed content with method fragment: { method_fragment } " )
103+ return signature
104+ except Exception as e :
105+ logging .error (f"Error signing content: { e } " )
106+ raise
107+
108+ def _generate_auth_header (self , domain : str ) -> str :
109+ """Generate DID authentication header"""
110+ try :
111+ did_document = self ._load_did_document ()
112+
113+ auth_header = generate_auth_header (
114+ did_document ,
115+ domain ,
116+ self ._sign_callback
117+ )
118+
119+ logging .info (f"Generated authentication header for domain { domain } : { auth_header [:30 ]} ..." )
120+ return auth_header
121+ except Exception as e :
122+ logging .error (f"Error generating authentication header: { e } " )
123+ raise
124+
125+ def get_auth_header (self , server_url : str , force_new : bool = False ) -> Dict [str , str ]:
126+ """
127+ Get authentication header.
128+
129+ Args:
130+ server_url: Server URL
131+ force_new: Whether to force generate a new DID authentication header
132+
133+ Returns:
134+ Dict[str, str]: HTTP header dictionary
135+ """
136+ domain = self ._get_domain (server_url )
137+
138+ # If there is a token and not forcing a new authentication header, return the token
139+ if domain in self .tokens and not force_new :
140+ token = self .tokens [domain ]
141+ logging .info (f"Using existing token for domain { domain } " )
142+ return {"Authorization" : f"Bearer { token } " }
143+
144+ # Otherwise, generate or use existing DID authentication header
145+ if domain not in self .auth_headers or force_new :
146+ self .auth_headers [domain ] = self ._generate_auth_header (domain )
147+
148+ logging .info (f"Using DID authentication header for domain { domain } " )
149+ return {"Authorization" : self .auth_headers [domain ]}
150+
151+ def update_token (self , server_url : str , headers : Dict [str , str ]) -> Optional [str ]:
152+ """
153+ Update token from response headers.
154+
155+ Args:
156+ server_url: Server URL
157+ headers: Response header dictionary
158+
159+ Returns:
160+ Optional[str]: Updated token, or None if no valid token is found
161+ """
162+ domain = self ._get_domain (server_url )
163+ auth_header = headers .get ("Authorization" )
164+
165+ if auth_header and auth_header .lower ().startswith ("bearer " ):
166+ token = auth_header [7 :] # Remove "Bearer " prefix
167+ self .tokens [domain ] = token
168+ logging .info (f"Updated token for domain { domain } : { token [:30 ]} ..." )
169+ return token
170+ else :
171+ logging .debug (f"No valid token found in response headers for domain { domain } " )
172+ return None
173+
174+ def clear_token (self , server_url : str ) -> None :
175+ """
176+ Clear token for the specified domain.
177+
178+ Args:
179+ server_url: Server URL
180+ """
181+ domain = self ._get_domain (server_url )
182+ if domain in self .tokens :
183+ del self .tokens [domain ]
184+ logging .info (f"Cleared token for domain { domain } " )
185+ else :
186+ logging .debug (f"No stored token for domain { domain } " )
187+
188+ def clear_all_tokens (self ) -> None :
189+ """Clear all tokens for all domains"""
190+ self .tokens .clear ()
191+ logging .info ("Cleared all tokens for all domains" )
192+
193+ # # Example usage
194+ # async def example_usage():
195+ # # Get current script directory
196+ # current_dir = Path(__file__).parent
197+ # # Get project root directory (parent of current directory)
198+ # base_dir = current_dir.parent
199+
200+ # # Create client with absolute paths
201+ # client = DIDWbaAuthHeader(
202+ # did_document_path=str(base_dir / "use_did_test_public/did.json"),
203+ # private_key_path=str(base_dir / "use_did_test_public/key-1_private.pem")
204+ # )
205+
206+ # server_url = "http://localhost:9870"
207+
208+ # # Get authentication header (first call, returns DID authentication header)
209+ # headers = client.get_auth_header(server_url)
210+
211+ # # Send request
212+ # async with aiohttp.ClientSession() as session:
213+ # async with session.get(
214+ # f"{server_url}/agents/travel/hotel/ad/ph/12345/ad.json",
215+ # headers=headers
216+ # ) as response:
217+ # # Check response
218+ # print(f"Status code: {response.status}")
219+
220+ # # If authentication is successful, update token
221+ # if response.status == 200:
222+ # token = client.update_token(server_url, dict(response.headers))
223+ # if token:
224+ # print(f"Received token: {token[:30]}...")
225+ # else:
226+ # print("No token received in response headers")
227+
228+ # # If authentication fails and a token was used, clear the token and retry
229+ # elif response.status == 401:
230+ # print("Invalid token, clearing and using DID authentication")
231+ # client.clear_token(server_url)
232+ # # Retry request here
233+
234+ # # Get authentication header again (if a token was obtained in the previous step, this will return a token authentication header)
235+ # headers = client.get_auth_header(server_url)
236+ # print(f"Header for second request: {headers}")
237+
238+ # # Force use of DID authentication header
239+ # headers = client.get_auth_header(server_url, force_new=True)
240+ # print(f"Forced use of DID authentication header: {headers}")
241+
242+ # # Test different domain
243+ # another_server_url = "http://api.example.com"
244+ # headers = client.get_auth_header(another_server_url)
245+ # print(f"Header for another domain: {headers}")
246+
247+ # if __name__ == "__main__":
248+ # asyncio.run(example_usage())
0 commit comments