Skip to content

Commit e51033c

Browse files
Feature: Added generic "usage" stats tracking for API that support "usage" key in thier responses.
1 parent 27da915 commit e51033c

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def obs_preprocessor(self, obs):
294294

295295
@cost_tracker_decorator
296296
def get_action(self, obs: Any) -> float:
297+
self.llm.reset_stats()
297298
if len(self.messages) == 0:
298299
self.config.goal.apply(self.llm, self.messages, obs)
299300
self.config.general_hints.apply(self.llm, self.messages)
@@ -309,7 +310,11 @@ def get_action(self, obs: Any) -> float:
309310
self._responses.append(response) # may be useful for debugging
310311
# self.messages.append(response.assistant_message) # this is tool call
311312

312-
agent_info = bgym.AgentInfo(think=think, chat_messages=self.messages, stats={})
313+
agent_info = bgym.AgentInfo(
314+
think=think,
315+
chat_messages=self.messages,
316+
stats=self.llm.stats.stats_dict,
317+
)
313318
return action, agent_info
314319

315320

src/agentlab/llm/tracking.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
import logging
12
import os
23
import re
34
import threading
5+
from collections import defaultdict
46
from contextlib import contextmanager
7+
from dataclasses import dataclass, field
58
from functools import cache
9+
from typing import Optional
610

711
import requests
812
from langchain_community.callbacks import bedrock_anthropic_callback, openai_info
9-
from typing import Optional
10-
import logging
13+
1114

1215
TRACKER = threading.local()
1316

@@ -136,15 +139,25 @@ class TrackAPIPricingMixin:
136139
Usage: provide the pricing_api to use in the constructor.
137140
"""
138141

142+
def reset_stats(self):
143+
self.stats = Stats()
144+
139145
def __init__(self, *args, **kwargs):
140146
pricing_api = kwargs.pop("pricing_api", None)
141147
self._pricing_api = pricing_api
142148
super().__init__(*args, **kwargs)
143149
self.set_pricing_attributes()
150+
self.reset_stats()
144151

145152
def __call__(self, *args, **kwargs):
146153
"""Call the API and update the pricing tracker."""
147154
response = self._call_api(*args, **kwargs)
155+
156+
usage = dict(getattr(response, "usage", {}))
157+
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
158+
usage |= {"n_api_calls": 1}
159+
self.stats.increment_stats_dict(usage)
160+
148161
self.update_pricing_tracker(response)
149162
return self._parse_response(response)
150163

@@ -215,3 +228,15 @@ def get_tokens_counts_from_response(self, response) -> tuple:
215228
"Unable to extract input and output tokens from the response. Defaulting to 0."
216229
)
217230
return 0, 0
231+
232+
233+
@dataclass
234+
class Stats:
235+
stats_dict: dict = field(default_factory=lambda: defaultdict(float))
236+
237+
def increment_stats_dict(self, stats_dict: dict):
238+
"""increment the stats_dict with the given values."""
239+
for k, v in stats_dict.items():
240+
self.stats_dict[k] += v
241+
242+

0 commit comments

Comments
 (0)