Skip to content

Commit ecbeaa7

Browse files
committed
fix(pricing): pass request context to processor for usage extraction
1 parent 420aded commit ecbeaa7

File tree

4 files changed

+165
-120
lines changed

4 files changed

+165
-120
lines changed

ccproxy/services/http/plugin_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ async def _handle_regular_request(
186186
headers=dict(response.headers),
187187
status_code=response.status_code,
188188
handler_config=handler_config,
189+
request_context=request_context,
189190
)
190191

191192
result = Response(

ccproxy/services/http/processor.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ async def process_response(
9595
status_code: int,
9696
handler_config: HandlerConfig,
9797
request_headers: dict[str, str] | None = None,
98+
request_context: Any | None = None,
9899
) -> tuple[bytes, dict[str, str]]:
99100
"""Process response through adapters and transformers.
100101
@@ -104,10 +105,15 @@ async def process_response(
104105
status_code: HTTP status code
105106
handler_config: Handler configuration
106107
request_headers: Original request headers for CORS processing
108+
request_context: Optional request context for storing extracted data
107109
108110
Returns:
109111
Tuple of (processed_body, processed_headers)
110112
"""
113+
# Extract usage from original response BEFORE format conversion
114+
if request_context and status_code < 400:
115+
self._extract_usage_before_conversion(body, request_context)
116+
111117
# Apply response adapter for successful responses
112118
processed_body = body
113119
if handler_config.response_adapter and status_code < 400:
@@ -200,6 +206,69 @@ async def _apply_response_adapter(
200206
)
201207
return body
202208

209+
def _extract_usage_before_conversion(
210+
self, body: bytes, request_context: Any
211+
) -> None:
212+
"""Extract usage data from Anthropic response before format conversion.
213+
214+
Args:
215+
body: Response body in Anthropic format
216+
request_context: Request context to store usage data
217+
"""
218+
try:
219+
# Parse response body
220+
response_data = json.loads(body)
221+
usage = response_data.get("usage", {})
222+
223+
if not usage:
224+
return
225+
226+
# Extract Anthropic-specific usage fields
227+
tokens_input = usage.get("input_tokens", 0)
228+
tokens_output = usage.get("output_tokens", 0)
229+
cache_read_tokens = usage.get("cache_read_input_tokens", 0)
230+
231+
# Handle both old and new cache creation token formats
232+
cache_write_tokens = usage.get("cache_creation_input_tokens", 0)
233+
234+
# New format has cache_creation as nested object
235+
if "cache_creation" in usage and isinstance(usage["cache_creation"], dict):
236+
cache_creation = usage["cache_creation"]
237+
# Sum all cache creation tokens from different tiers
238+
cache_write_tokens = cache_creation.get(
239+
"ephemeral_5m_input_tokens", 0
240+
) + cache_creation.get("ephemeral_1h_input_tokens", 0)
241+
242+
# Update request context with usage data
243+
if hasattr(request_context, "metadata"):
244+
request_context.metadata.update(
245+
{
246+
"tokens_input": tokens_input,
247+
"tokens_output": tokens_output,
248+
"tokens_total": tokens_input + tokens_output,
249+
"cache_read_tokens": cache_read_tokens,
250+
"cache_write_tokens": cache_write_tokens,
251+
# Note: cost calculation happens in the adapter with pricing service
252+
}
253+
)
254+
255+
self.logger.debug(
256+
"usage_extracted_before_conversion",
257+
tokens_input=tokens_input,
258+
tokens_output=tokens_output,
259+
cache_read_tokens=cache_read_tokens,
260+
cache_write_tokens=cache_write_tokens,
261+
source="processor",
262+
)
263+
264+
except (json.JSONDecodeError, UnicodeDecodeError):
265+
# Silent fail - usage extraction is non-critical
266+
pass
267+
except Exception as e:
268+
self.logger.debug(
269+
"usage_extraction_failed", error=str(e), source="processor"
270+
)
271+
203272
def _filter_internal_headers(self, headers: dict[str, str]) -> dict[str, str]:
204273
"""Filter out internal headers that shouldn't be sent upstream.
205274

plugins/claude_api/adapter.py

Lines changed: 70 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ async def _execute_request(
347347
request_context=request_context, # Pass the actual RequestContext object
348348
)
349349

350+
# For non-streaming responses, calculate cost based on usage already extracted in processor
351+
if not is_streaming and request_context:
352+
await self._calculate_cost_for_usage(request_context)
353+
350354
# For deferred streaming responses, return directly (metrics collector already has cost calculation)
351355
if isinstance(response, DeferredStreaming):
352356
return response
@@ -355,131 +359,89 @@ async def _execute_request(
355359
if is_streaming and isinstance(response, StreamingResponse):
356360
return await self._wrap_streaming_response(response, request_context)
357361

358-
# For non-streaming responses, extract usage data if available
359-
if not is_streaming and hasattr(response, "body"):
360-
# Get response body (might be bytes or memoryview)
361-
response_body = response.body
362-
if isinstance(response_body, memoryview):
363-
response_body = bytes(response_body)
364-
await self._extract_usage_from_response(response_body, request_context)
365-
366362
return response
367363

368-
async def _extract_usage_from_response(
369-
self, body: bytes | str, request_context: "RequestContext"
364+
async def _calculate_cost_for_usage(
365+
self, request_context: "RequestContext"
370366
) -> None:
371-
"""Extract usage data from response body and update context.
372-
373-
Common function used by both streaming and non-streaming responses.
367+
"""Calculate cost for usage data already extracted in processor.
374368
375369
Args:
376-
body: Response body (bytes or string)
377-
request_context: Request context to update with usage data
370+
request_context: Request context with usage data from processor
378371
"""
379-
try:
380-
import json
381-
382-
# Convert body to string if needed
383-
body_str = body
384-
if isinstance(body_str, bytes):
385-
body_str = body_str.decode("utf-8")
386-
387-
# Parse response to extract usage
388-
response_data = json.loads(body_str)
389-
usage = response_data.get("usage", {})
372+
# Check if we have usage data from the processor
373+
metadata = request_context.metadata
374+
tokens_input = metadata.get("tokens_input", 0)
375+
tokens_output = metadata.get("tokens_output", 0)
390376

391-
if not usage:
392-
return
377+
# Skip if no usage data available
378+
if not (tokens_input or tokens_output):
379+
return
393380

394-
# Extract Claude-specific usage fields
395-
tokens_input = usage.get("input_tokens", 0)
396-
tokens_output = usage.get("output_tokens", 0)
397-
cache_read_tokens = usage.get("cache_read_input_tokens", 0)
398-
cache_write_tokens = usage.get("cache_creation_input_tokens", 0)
381+
# Get pricing service and calculate cost
382+
pricing_service = self._get_pricing_service()
383+
if not pricing_service:
384+
return
399385

400-
# Calculate cost using pricing service if available
401-
cost_usd = None
402-
pricing_service = self._get_pricing_service()
403-
self.logger.debug(
404-
"pricing_service_check",
405-
has_pricing_service=pricing_service is not None,
406-
source="non_streaming",
386+
try:
387+
model = metadata.get("model", "claude-3-5-sonnet-20241022")
388+
cache_read_tokens = metadata.get("cache_read_tokens", 0)
389+
cache_write_tokens = metadata.get("cache_write_tokens", 0)
390+
391+
# Import pricing exceptions
392+
from plugins.pricing.exceptions import (
393+
ModelPricingNotFoundError,
394+
PricingDataNotLoadedError,
395+
PricingServiceDisabledError,
407396
)
408-
if pricing_service:
409-
try:
410-
model = request_context.metadata.get(
411-
"model", "claude-3-5-sonnet-20241022"
412-
)
413-
# Import pricing exceptions
414-
from plugins.pricing.exceptions import (
415-
ModelPricingNotFoundError,
416-
PricingDataNotLoadedError,
417-
PricingServiceDisabledError,
418-
)
419-
420-
cost_decimal = await pricing_service.calculate_cost(
421-
model_name=model,
422-
input_tokens=tokens_input,
423-
output_tokens=tokens_output,
424-
cache_read_tokens=cache_read_tokens,
425-
cache_write_tokens=cache_write_tokens,
426-
)
427-
cost_usd = float(cost_decimal)
428-
self.logger.debug(
429-
"cost_calculated",
430-
model=model,
431-
cost_usd=cost_usd,
432-
tokens_input=tokens_input,
433-
tokens_output=tokens_output,
434-
)
435-
except ModelPricingNotFoundError as e:
436-
self.logger.warning(
437-
"model_pricing_not_found",
438-
model=model,
439-
message=str(e),
440-
tokens_input=tokens_input,
441-
tokens_output=tokens_output,
442-
)
443-
except PricingDataNotLoadedError as e:
444-
self.logger.warning(
445-
"pricing_data_not_loaded",
446-
model=model,
447-
message=str(e),
448-
)
449-
except PricingServiceDisabledError as e:
450-
self.logger.debug(
451-
"pricing_service_disabled",
452-
message=str(e),
453-
)
454-
except Exception as e:
455-
self.logger.debug(
456-
"cost_calculation_failed", error=str(e), model=model
457-
)
458397

459-
# Update request context with usage data
460-
request_context.metadata.update(
461-
{
462-
"tokens_input": tokens_input,
463-
"tokens_output": tokens_output,
464-
"tokens_total": tokens_input + tokens_output,
465-
"cache_read_tokens": cache_read_tokens,
466-
"cache_write_tokens": cache_write_tokens,
467-
"cost_usd": cost_usd or 0.0,
468-
}
398+
cost_decimal = await pricing_service.calculate_cost(
399+
model_name=model,
400+
input_tokens=tokens_input,
401+
output_tokens=tokens_output,
402+
cache_read_tokens=cache_read_tokens,
403+
cache_write_tokens=cache_write_tokens,
469404
)
405+
cost_usd = float(cost_decimal)
406+
407+
# Update context with calculated cost
408+
metadata["cost_usd"] = cost_usd
470409

471410
self.logger.debug(
472-
"usage_extracted",
411+
"cost_calculated",
412+
model=model,
413+
cost_usd=cost_usd,
473414
tokens_input=tokens_input,
474415
tokens_output=tokens_output,
475416
cache_read_tokens=cache_read_tokens,
476417
cache_write_tokens=cache_write_tokens,
477-
cost_usd=cost_usd,
478-
source="response_body",
418+
source="non_streaming",
419+
)
420+
except ModelPricingNotFoundError as e:
421+
self.logger.warning(
422+
"model_pricing_not_found",
423+
model=model,
424+
message=str(e),
425+
tokens_input=tokens_input,
426+
tokens_output=tokens_output,
427+
)
428+
except PricingDataNotLoadedError as e:
429+
self.logger.warning(
430+
"pricing_data_not_loaded",
431+
model=model,
432+
message=str(e),
433+
)
434+
except PricingServiceDisabledError as e:
435+
self.logger.debug(
436+
"pricing_service_disabled",
437+
message=str(e),
479438
)
480-
481439
except Exception as e:
482-
self.logger.debug("usage_extraction_failed", error=str(e))
440+
self.logger.debug(
441+
"cost_calculation_failed",
442+
error=str(e),
443+
model=metadata.get("model"),
444+
)
483445

484446
async def _wrap_streaming_response(
485447
self, response: StreamingResponse, request_context: "RequestContext"
@@ -638,7 +600,9 @@ async def wrapped_iterator() -> AsyncIterator[bytes]:
638600
model=model,
639601
message=str(e),
640602
tokens_input=usage_metrics.get("tokens_input"),
641-
tokens_output=usage_metrics.get("tokens_output"),
603+
tokens_output=usage_metrics.get(
604+
"tokens_output"
605+
),
642606
category="pricing",
643607
)
644608
except PricingDataNotLoadedError as e:

tests/unit/plugins/test_claude_api_pricing.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_get_pricing_service_with_missing_runtime(self):
102102
async def test_extract_usage_with_pricing(
103103
self, adapter_with_pricing, mock_pricing_service
104104
):
105-
"""Test that usage extraction uses pricing service for cost calculation."""
105+
"""Test that cost calculation uses pricing service when available."""
106106
import time
107107

108108
from ccproxy.observability.context import RequestContext
@@ -113,14 +113,19 @@ async def test_extract_usage_with_pricing(
113113
)
114114
request_context.metadata["model"] = "claude-3-5-sonnet-20241022"
115115

116-
# Mock response body with usage data
117-
response_body = b'{"usage": {"input_tokens": 1000, "output_tokens": 500}}'
118-
119-
# Extract usage from response
120-
await adapter_with_pricing._extract_usage_from_response(
121-
response_body, request_context
116+
# Simulate usage data already extracted in processor
117+
request_context.metadata.update(
118+
{
119+
"tokens_input": 1000,
120+
"tokens_output": 500,
121+
"cache_read_tokens": 0,
122+
"cache_write_tokens": 0,
123+
}
122124
)
123125

126+
# Calculate cost with pricing service
127+
await adapter_with_pricing._calculate_cost_for_usage(request_context)
128+
124129
# Verify pricing service was called
125130
mock_pricing_service.calculate_cost.assert_called_once_with(
126131
model_name="claude-3-5-sonnet-20241022",
@@ -156,16 +161,22 @@ async def test_extract_usage_without_pricing(self):
156161
)
157162
request_context.metadata["model"] = "claude-3-5-sonnet-20241022"
158163

159-
# Mock response body with usage data
160-
response_body = b'{"usage": {"input_tokens": 1000, "output_tokens": 500}}'
164+
# Simulate usage data already extracted in processor
165+
request_context.metadata.update(
166+
{
167+
"tokens_input": 1000,
168+
"tokens_output": 500,
169+
"tokens_total": 1500,
170+
}
171+
)
161172

162-
# Extract usage from response (should not fail)
163-
await adapter._extract_usage_from_response(response_body, request_context)
173+
# Calculate cost (should not fail even without pricing service)
174+
await adapter._calculate_cost_for_usage(request_context)
164175

165-
# Verify tokens were extracted even without pricing
176+
# Verify tokens are still in metadata
166177
assert request_context.metadata["tokens_input"] == 1000
167178
assert request_context.metadata["tokens_output"] == 500
168179
assert request_context.metadata["tokens_total"] == 1500
169180

170-
# Cost should be 0 when pricing service is not available
171-
assert request_context.metadata["cost_usd"] == 0.0
181+
# Cost should not be set when pricing service is not available
182+
assert "cost_usd" not in request_context.metadata

0 commit comments

Comments
 (0)