Skip to content

Commit e411f7d

Browse files
authored
Adding suffix to tracker decorator (#169)
* adding suffix to decorator * moving tracker to suffix
1 parent 7e49aa7 commit e411f7d

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/agentlab/llm/tracking.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from functools import cache
21
import os
32
import threading
43
from contextlib import contextmanager
4+
from functools import cache
55

66
import requests
77
from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS
@@ -10,10 +10,13 @@
1010

1111

1212
class LLMTracker:
13-
def __init__(self):
13+
def __init__(self, suffix=""):
1414
self.input_tokens = 0
1515
self.output_tokens = 0
1616
self.cost = 0.0
17+
self.input_tokens_key = "input_tokens_" + suffix if suffix else "input_tokens"
18+
self.output_tokens_key = "output_tokens_" + suffix if suffix else "output_tokens"
19+
self.cost_key = "cost_" + suffix if suffix else "cost"
1720

1821
def __call__(self, input_tokens: int, output_tokens: int, cost: float):
1922
self.input_tokens += input_tokens
@@ -23,9 +26,9 @@ def __call__(self, input_tokens: int, output_tokens: int, cost: float):
2326
@property
2427
def stats(self):
2528
return {
26-
"input_tokens": self.input_tokens,
27-
"output_tokens": self.output_tokens,
28-
"cost": self.cost,
29+
self.input_tokens_key: self.input_tokens,
30+
self.output_tokens_key: self.output_tokens,
31+
self.cost_key: self.cost,
2932
}
3033

3134
def add_tracker(self, tracker: "LLMTracker"):
@@ -36,12 +39,12 @@ def __repr__(self):
3639

3740

3841
@contextmanager
39-
def set_tracker():
42+
def set_tracker(suffix=""):
4043
global TRACKER
4144
if not hasattr(TRACKER, "instance"):
4245
TRACKER.instance = None
4346
previous_tracker = TRACKER.instance # type: LLMTracker
44-
TRACKER.instance = LLMTracker()
47+
TRACKER.instance = LLMTracker(suffix)
4548
try:
4649
yield TRACKER.instance
4750
finally:
@@ -52,9 +55,9 @@ def set_tracker():
5255
TRACKER.instance = previous_tracker
5356

5457

55-
def cost_tracker_decorator(get_action):
58+
def cost_tracker_decorator(get_action, suffix=""):
5659
def wrapper(self, obs):
57-
with set_tracker() as tracker:
60+
with set_tracker(suffix) as tracker:
5861
action, agent_info = get_action(self, obs)
5962
agent_info.get("stats").update(tracker.stats)
6063
return action, agent_info

0 commit comments

Comments
 (0)