Skip to content

Commit e9ee842

Browse files
committed
Refactor header handling in AuthenticatedOTLPExporter to prevent critical headers from being overridden by user-supplied values. Update tests to verify protection of critical headers and ensure proper JWT token usage.
1 parent beca4a8 commit e9ee842

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

agentops/client/api/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def async_request(
9999
url = self._get_full_url(path)
100100

101101
try:
102-
response_data = await HttpClient.async_request(
102+
response_data = await self.http_client.async_request(
103103
method=method, url=url, data=data, headers=headers, timeout=timeout
104104
)
105105
return response_data

agentops/sdk/exporters.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ def __init__(
5454
# Store any additional kwargs for potential future use
5555
self._custom_kwargs = kwargs
5656

57+
# Filter headers to prevent override of critical headers
58+
filtered_headers = self._filter_user_headers(headers) if headers else None
59+
5760
# Initialize parent with only known parameters
5861
parent_kwargs = {}
59-
if headers is not None:
60-
parent_kwargs["headers"] = headers
62+
if filtered_headers is not None:
63+
parent_kwargs["headers"] = filtered_headers
6164
if timeout is not None:
6265
parent_kwargs["timeout"] = timeout
6366
if compression is not None:
@@ -66,24 +69,49 @@ def __init__(
6669
super().__init__(endpoint=endpoint, **parent_kwargs)
6770

6871
def _get_current_jwt(self) -> Optional[str]:
69-
"""Get the current JWT token from the provider."""
72+
"""Get the current JWT token from the provider or stored JWT."""
7073
if self._jwt_provider:
7174
try:
7275
return self._jwt_provider()
7376
except Exception as e:
7477
logger.warning(f"Failed to get JWT token: {e}")
75-
return None
78+
return self._jwt
79+
80+
def _filter_user_headers(self, headers: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
81+
"""Filter user-supplied headers to prevent override of critical headers."""
82+
if not headers:
83+
return None
84+
85+
# Define critical headers that cannot be overridden by user-supplied headers
86+
PROTECTED_HEADERS = {
87+
"authorization",
88+
"content-type",
89+
"user-agent",
90+
"x-api-key",
91+
"api-key",
92+
"bearer",
93+
"x-auth-token",
94+
"x-session-token",
95+
}
96+
97+
filtered_headers = {}
98+
for key, value in headers.items():
99+
if key.lower() not in PROTECTED_HEADERS:
100+
filtered_headers[key] = value
101+
102+
return filtered_headers if filtered_headers else None
76103

77104
def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
78105
"""Prepare headers with current JWT token."""
79106
# Start with base headers
80107
prepared_headers = dict(self._headers)
81108

82-
# Add any additional headers
83-
if headers:
84-
prepared_headers.update(headers)
109+
# Add any additional headers, but only allow non-critical headers
110+
filtered_headers = self._filter_user_headers(headers)
111+
if filtered_headers:
112+
prepared_headers.update(filtered_headers)
85113

86-
# Add current JWT token if available
114+
# Add current JWT token if available (this ensures Authorization cannot be overridden)
87115
jwt_token = self._get_current_jwt()
88116
if jwt_token:
89117
prepared_headers["Authorization"] = f"Bearer {jwt_token}"

tests/unit/sdk/test_exporters.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,30 @@ def test_headers_merging(self):
172172
# Verify the exporter was created successfully
173173
self.assertIsInstance(exporter, AuthenticatedOTLPExporter)
174174

175-
def test_headers_override_authorization(self):
176-
"""Test that custom Authorization header overrides the default one."""
177-
custom_headers = {"Authorization": "Custom-Auth custom-token", "X-Custom-Header": "test-value"}
178-
179-
exporter = AuthenticatedOTLPExporter(endpoint=self.endpoint, jwt=self.jwt, headers=custom_headers)
175+
def test_headers_protected_from_override(self):
176+
"""Test that critical headers cannot be overridden by user-supplied headers."""
177+
# Attempt to override critical headers
178+
malicious_headers = {
179+
"Authorization": "Malicious-Auth malicious-token",
180+
"Content-Type": "text/plain",
181+
"User-Agent": "malicious-agent",
182+
"X-API-Key": "malicious-key",
183+
"X-Custom-Header": "test-value", # This should be allowed
184+
}
185+
186+
exporter = AuthenticatedOTLPExporter(endpoint=self.endpoint, jwt=self.jwt, headers=malicious_headers)
187+
188+
# Test the _prepare_headers method directly to verify protection
189+
prepared_headers = exporter._prepare_headers(malicious_headers)
190+
191+
# Critical headers should not be overridden
192+
self.assertEqual(prepared_headers["Authorization"], f"Bearer {self.jwt}")
193+
self.assertNotEqual(prepared_headers.get("Content-Type"), "text/plain")
194+
self.assertNotEqual(prepared_headers.get("User-Agent"), "malicious-agent")
195+
self.assertNotIn("X-API-Key", prepared_headers) # Should be filtered out
196+
197+
# Non-critical headers should be allowed
198+
self.assertEqual(prepared_headers.get("X-Custom-Header"), "test-value")
180199

181200
# Verify the exporter was created successfully
182201
self.assertIsInstance(exporter, AuthenticatedOTLPExporter)

0 commit comments

Comments
 (0)