@@ -347,6 +347,10 @@ async def _execute_request(
347347 request_context = request_context , # Pass the actual RequestContext object
348348 )
349349
350+ # For non-streaming responses, calculate cost based on usage already extracted in processor
351+ if not is_streaming and request_context :
352+ await self ._calculate_cost_for_usage (request_context )
353+
350354 # For deferred streaming responses, return directly (metrics collector already has cost calculation)
351355 if isinstance (response , DeferredStreaming ):
352356 return response
@@ -355,131 +359,89 @@ async def _execute_request(
355359 if is_streaming and isinstance (response , StreamingResponse ):
356360 return await self ._wrap_streaming_response (response , request_context )
357361
358- # For non-streaming responses, extract usage data if available
359- if not is_streaming and hasattr (response , "body" ):
360- # Get response body (might be bytes or memoryview)
361- response_body = response .body
362- if isinstance (response_body , memoryview ):
363- response_body = bytes (response_body )
364- await self ._extract_usage_from_response (response_body , request_context )
365-
366362 return response
367363
async def _calculate_cost_for_usage(
    self, request_context: "RequestContext"
) -> None:
    """Calculate cost for usage data already extracted in processor.

    Reads the token counts from ``request_context.metadata`` and, when a
    pricing service is available, stores the computed cost under
    ``metadata["cost_usd"]``. The calculation is best-effort: every
    failure is logged and swallowed so the response is still returned.

    Args:
        request_context: Request context with usage data from processor
    """
    metadata = request_context.metadata

    # ``or 0`` also covers keys that are present but explicitly None, so
    # the pricing service never receives None as a token count.
    tokens_input = metadata.get("tokens_input") or 0
    tokens_output = metadata.get("tokens_output") or 0

    # Skip if no usage data available
    if not (tokens_input or tokens_output):
        return

    # Get pricing service and calculate cost
    pricing_service = self._get_pricing_service()
    if not pricing_service:
        return

    model = metadata.get("model", "claude-3-5-sonnet-20241022")
    cache_read_tokens = metadata.get("cache_read_tokens") or 0
    cache_write_tokens = metadata.get("cache_write_tokens") or 0

    # Import the pricing exceptions in their own guarded step. If this
    # import lived inside the try below and failed, the exception names
    # would be unbound when the except clauses are evaluated, and the
    # resulting UnboundLocalError would escape this best-effort method.
    try:
        from plugins.pricing.exceptions import (
            ModelPricingNotFoundError,
            PricingDataNotLoadedError,
            PricingServiceDisabledError,
        )
    except ImportError as e:
        self.logger.debug(
            "cost_calculation_failed",
            error=str(e),
            model=metadata.get("model"),
        )
        return

    try:
        cost_decimal = await pricing_service.calculate_cost(
            model_name=model,
            input_tokens=tokens_input,
            output_tokens=tokens_output,
            cache_read_tokens=cache_read_tokens,
            cache_write_tokens=cache_write_tokens,
        )
        cost_usd = float(cost_decimal)

        # Update context with calculated cost
        metadata["cost_usd"] = cost_usd

        self.logger.debug(
            "cost_calculated",
            model=model,
            cost_usd=cost_usd,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            cache_read_tokens=cache_read_tokens,
            cache_write_tokens=cache_write_tokens,
            source="non_streaming",
        )
    except ModelPricingNotFoundError as e:
        self.logger.warning(
            "model_pricing_not_found",
            model=model,
            message=str(e),
            tokens_input=tokens_input,
            tokens_output=tokens_output,
        )
    except PricingDataNotLoadedError as e:
        self.logger.warning(
            "pricing_data_not_loaded",
            model=model,
            message=str(e),
        )
    except PricingServiceDisabledError as e:
        self.logger.debug(
            "pricing_service_disabled",
            message=str(e),
        )
    except Exception as e:
        # Deliberate catch-all: cost calculation must never break the
        # request path, so unexpected errors are only logged.
        self.logger.debug(
            "cost_calculation_failed",
            error=str(e),
            model=metadata.get("model"),
        )
483445
484446 async def _wrap_streaming_response (
485447 self , response : StreamingResponse , request_context : "RequestContext"
@@ -638,7 +600,9 @@ async def wrapped_iterator() -> AsyncIterator[bytes]:
638600 model = model ,
639601 message = str (e ),
640602 tokens_input = usage_metrics .get ("tokens_input" ),
641- tokens_output = usage_metrics .get ("tokens_output" ),
603+ tokens_output = usage_metrics .get (
604+ "tokens_output"
605+ ),
642606 category = "pricing" ,
643607 )
644608 except PricingDataNotLoadedError as e :
0 commit comments