|
5 | 5 | from ..core import get_logger |
6 | 6 | from ..core.db import AsyncSession |
7 | 7 | from ..core.settings import settings |
| 8 | +from .price import sats_usd_price |
8 | 9 |
|
9 | 10 | logger = get_logger(__name__) |
10 | 11 |
|
@@ -64,6 +65,56 @@ async def calculate_cost( # todo: can be sync |
64 | 65 | ) |
65 | 66 | return cost_data |
66 | 67 |
|
| 68 | + usage_data = response_data["usage"] |
| 69 | + |
| 70 | + usd_cost = 0.0 |
| 71 | + |
| 72 | + # Prioritize cost_details.upstream_inference_cost |
| 73 | + if "cost_details" in usage_data: |
| 74 | + usd_cost = float( |
| 75 | + usage_data["cost_details"].get("upstream_inference_cost", 0) or 0 |
| 76 | + ) |
| 77 | + |
| 78 | + # Fallback to cost field if upstream_inference_cost is 0 |
| 79 | + if usd_cost == 0 and "cost" in usage_data: |
| 80 | + try: |
| 81 | + usd_cost = float(usage_data.get("cost", 0) or 0) |
| 82 | + except Exception: |
| 83 | + pass |
| 84 | + |
| 85 | + if usd_cost > 0: |
| 86 | + try: |
| 87 | + sats_per_usd = 1.0 / sats_usd_price() |
| 88 | + cost_in_sats = usd_cost * sats_per_usd |
| 89 | + cost_in_msats = math.ceil(cost_in_sats * 1000) |
| 90 | + |
| 91 | + logger.info( |
| 92 | + "Using cost from usage data/details", |
| 93 | + extra={ |
| 94 | + "usd_cost": usd_cost, |
| 95 | + "cost_in_sats": cost_in_sats, |
| 96 | + "cost_in_msats": cost_in_msats, |
| 97 | + "model": response_data.get("model", "unknown"), |
| 98 | + }, |
| 99 | + ) |
| 100 | + |
| 101 | + return CostData( |
| 102 | + base_msats=-1, |
| 103 | + input_msats=-1, # Cost field doesn't break down by token type |
| 104 | + output_msats=-1, |
| 105 | + total_msats=cost_in_msats, |
| 106 | + ) |
| 107 | + except Exception as e: |
| 108 | + logger.warning( |
| 109 | + "Error calculating cost from usage data", |
| 110 | + extra={ |
| 111 | + "error": str(e), |
| 112 | + "usd_cost": usd_cost, |
| 113 | + "model": response_data.get("model", "unknown"), |
| 114 | + }, |
| 115 | + ) |
| 116 | + # Fall through to token-based calculation |
| 117 | + |
67 | 118 | MSATS_PER_1K_INPUT_TOKENS: float = ( |
68 | 119 | float(settings.fixed_per_1k_input_tokens) * 1000.0 |
69 | 120 | ) |
@@ -129,10 +180,19 @@ async def calculate_cost( # todo: can be sync |
129 | 180 | ) |
130 | 181 | return cost_data |
131 | 182 |
|
132 | | - input_tokens = response_data.get("usage", {}).get("prompt_tokens", 0) |
133 | | - output_tokens = response_data.get("usage", {}).get("completion_tokens", 0) |
| 183 | + input_tokens = usage_data.get("prompt_tokens", 0) |
| 184 | + output_tokens = usage_data.get("completion_tokens", 0) |
| 185 | + |
| 186 | + # added for response api |
| 187 | + input_tokens = ( |
| 188 | + input_tokens if input_tokens != 0 else usage_data.get("input_tokens", 0) |
| 189 | + ) |
| 190 | + output_tokens = ( |
| 191 | + output_tokens if output_tokens != 0 else usage_data.get("output_tokens", 0) |
| 192 | + ) |
134 | 193 |
|
135 | 194 | input_msats = round(input_tokens / 1000 * MSATS_PER_1K_INPUT_TOKENS, 3) |
| 195 | + |
136 | 196 | output_msats = round(output_tokens / 1000 * MSATS_PER_1K_OUTPUT_TOKENS, 3) |
137 | 197 | token_based_cost = math.ceil(input_msats + output_msats) |
138 | 198 |
|
|
0 commit comments