|
| 1 | +import logging |
1 | 2 | import os |
2 | 3 | import re |
3 | 4 | import threading |
| 5 | +from collections import defaultdict |
4 | 6 | from contextlib import contextmanager |
| 7 | +from dataclasses import dataclass, field |
5 | 8 | from functools import cache |
| 9 | +from typing import Optional |
6 | 10 |
|
7 | 11 | import requests |
8 | 12 | from langchain_community.callbacks import bedrock_anthropic_callback, openai_info |
9 | | -from typing import Optional |
10 | | -import logging |
| 13 | + |
11 | 14 |
|
12 | 15 | TRACKER = threading.local() |
13 | 16 |
|
@@ -136,15 +139,25 @@ class TrackAPIPricingMixin: |
136 | 139 | Usage: provide the pricing_api to use in the constructor. |
137 | 140 | """ |
138 | 141 |
|
| 142 | + def reset_stats(self): |
| 143 | + self.stats = Stats() |
| 144 | + |
139 | 145 | def __init__(self, *args, **kwargs): |
140 | 146 | pricing_api = kwargs.pop("pricing_api", None) |
141 | 147 | self._pricing_api = pricing_api |
142 | 148 | super().__init__(*args, **kwargs) |
143 | 149 | self.set_pricing_attributes() |
| 150 | + self.reset_stats() |
144 | 151 |
|
145 | 152 | def __call__(self, *args, **kwargs): |
146 | 153 | """Call the API and update the pricing tracker.""" |
147 | 154 | response = self._call_api(*args, **kwargs) |
| 155 | + |
| 156 | + usage = dict(getattr(response, "usage", {})) |
| 157 | + usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))} |
| 158 | + usage |= {"n_api_calls": 1} |
| 159 | + self.stats.increment_stats_dict(usage) |
| 160 | + |
148 | 161 | self.update_pricing_tracker(response) |
149 | 162 | return self._parse_response(response) |
150 | 163 |
|
@@ -215,3 +228,15 @@ def get_tokens_counts_from_response(self, response) -> tuple: |
215 | 228 | "Unable to extract input and output tokens from the response. Defaulting to 0." |
216 | 229 | ) |
217 | 230 | return 0, 0 |
| 231 | + |
| 232 | + |
| 233 | +@dataclass |
| 234 | +class Stats: |
| 235 | + stats_dict: dict = field(default_factory=lambda: defaultdict(float)) |
| 236 | + |
| 237 | + def increment_stats_dict(self, stats_dict: dict): |
| 238 | + """increment the stats_dict with the given values.""" |
| 239 | + for k, v in stats_dict.items(): |
| 240 | + self.stats_dict[k] += v |
| 241 | + |
| 242 | + |
0 commit comments