Skip to content

Commit 36b092a

Browse files
black formatting
1 parent b927aad commit 36b092a

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/agentlab/llm/tracking.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ def __call__(self, *args, **kwargs):
163163
response = self._call_api(*args, **kwargs)
164164

165165
usage = dict(getattr(response, "usage", {}))
166-
if 'prompt_tokens_details' in usage:
167-
usage['cached_tokens'] = usage['prompt_token_details'].cached_tokens
168-
if 'input_tokens_details' in usage:
169-
usage['cached_tokens'] = usage['input_tokens_details'].cached_tokens
166+
if "prompt_tokens_details" in usage:
167+
usage["cached_tokens"] = usage["prompt_token_details"].cached_tokens
168+
if "input_tokens_details" in usage:
169+
usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
170170
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
171171
usage |= {"n_api_calls": 1}
172172
usage |= {"effective_cost": self.get_effective_cost(response)}
@@ -306,21 +306,21 @@ def get_effective_cost_from_openai_api(self, response) -> float:
306306
if usage is None:
307307
logging.warning("No usage information found in the response. Defaulting cost to 0.0.")
308308
return 0.0
309-
api_type = 'chatcompletion' if hasattr(usage, "prompt_tokens_details") else 'response'
310-
if api_type == 'chatcompletion':
309+
api_type = "chatcompletion" if hasattr(usage, "prompt_tokens_details") else "response"
310+
if api_type == "chatcompletion":
311311
total_input_tokens = usage.prompt_tokens
312312
output_tokens = usage.completion_tokens
313313
cached_input_tokens = usage.prompt_tokens_details.cached_tokens
314314
non_cached_input_tokens = total_input_tokens - cached_input_tokens
315-
elif api_type == 'response':
315+
elif api_type == "response":
316316
total_input_tokens = usage.input_tokens
317317
output_tokens = usage.output_tokens
318318
cached_input_tokens = usage.input_tokens_details.cached_tokens
319319
non_cached_input_tokens = total_input_tokens - cached_input_tokens
320320
else:
321321
logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")
322322
return 0.0
323-
323+
324324
cache_read_cost = self.input_cost * OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"]
325325
effective_cost = (
326326
self.input_cost * non_cached_input_tokens

0 commit comments

Comments
 (0)