@@ -54,10 +54,13 @@ def __init__(
5454 # Store any additional kwargs for potential future use
5555 self ._custom_kwargs = kwargs
5656
57+ # Filter headers to prevent override of critical headers
58+ filtered_headers = self ._filter_user_headers (headers ) if headers else None
59+
5760 # Initialize parent with only known parameters
5861 parent_kwargs = {}
59- if headers is not None :
60- parent_kwargs ["headers" ] = headers
62+ if filtered_headers is not None :
63+ parent_kwargs ["headers" ] = filtered_headers
6164 if timeout is not None :
6265 parent_kwargs ["timeout" ] = timeout
6366 if compression is not None :
@@ -66,24 +69,49 @@ def __init__(
6669 super ().__init__ (endpoint = endpoint , ** parent_kwargs )
6770
6871 def _get_current_jwt (self ) -> Optional [str ]:
69- """Get the current JWT token from the provider."""
72+ """Get the current JWT token from the provider or stored JWT ."""
7073 if self ._jwt_provider :
7174 try :
7275 return self ._jwt_provider ()
7376 except Exception as e :
7477 logger .warning (f"Failed to get JWT token: { e } " )
75- return None
78+ return self ._jwt
79+
80+ def _filter_user_headers (self , headers : Optional [Dict [str , str ]]) -> Optional [Dict [str , str ]]:
81+ """Filter user-supplied headers to prevent override of critical headers."""
82+ if not headers :
83+ return None
84+
85+ # Define critical headers that cannot be overridden by user-supplied headers
86+ PROTECTED_HEADERS = {
87+ "authorization" ,
88+ "content-type" ,
89+ "user-agent" ,
90+ "x-api-key" ,
91+ "api-key" ,
92+ "bearer" ,
93+ "x-auth-token" ,
94+ "x-session-token" ,
95+ }
96+
97+ filtered_headers = {}
98+ for key , value in headers .items ():
99+ if key .lower () not in PROTECTED_HEADERS :
100+ filtered_headers [key ] = value
101+
102+ return filtered_headers if filtered_headers else None
76103
77104 def _prepare_headers (self , headers : Optional [Dict [str , str ]] = None ) -> Dict [str , str ]:
78105 """Prepare headers with current JWT token."""
79106 # Start with base headers
80107 prepared_headers = dict (self ._headers )
81108
82- # Add any additional headers
83- if headers :
84- prepared_headers .update (headers )
109+ # Add any additional headers, but only allow non-critical headers
110+ filtered_headers = self ._filter_user_headers (headers )
111+ if filtered_headers :
112+ prepared_headers .update (filtered_headers )
85113
86- # Add current JWT token if available
114+ # Add current JWT token if available (this ensures Authorization cannot be overridden)
87115 jwt_token = self ._get_current_jwt ()
88116 if jwt_token :
89117 prepared_headers ["Authorization" ] = f"Bearer { jwt_token } "
0 commit comments