|
1 | 1 | from __future__ import annotations as _annotations |
2 | 2 |
|
3 | 3 | import dataclasses |
| 4 | +import threading |
4 | 5 | from copy import copy |
5 | 6 | from dataclasses import dataclass, fields |
6 | 7 | from typing import Annotated, Any |
@@ -189,26 +190,57 @@ class RunUsage(UsageBase): |
189 | 190 | details: dict[str, int] = dataclasses.field(default_factory=dict) |
190 | 191 | """Any extra details returned by the model.""" |
191 | 192 |
|
| 193 | + """Lock to prevent race conditions when incrementing usage from concurrent tool calls.""" |
| 194 | + |
| 195 | + @property |
| 196 | + def _lock(self) -> threading.Lock: |
| 197 | + """Get the lock, creating it if it doesn't exist.""" |
| 198 | + if not hasattr(self, '__lock'): |
| 199 | + self.__lock = threading.Lock() |
| 200 | + return self.__lock |
| 201 | + |
192 | 202 | def incr(self, incr_usage: RunUsage | RequestUsage) -> None: |
193 | 203 | """Increment the usage in place. |
194 | 204 |
|
195 | 205 | Args: |
196 | 206 | incr_usage: The usage to increment by. |
197 | 207 | """ |
198 | | - if isinstance(incr_usage, RunUsage): |
199 | | - self.requests += incr_usage.requests |
200 | | - self.tool_calls += incr_usage.tool_calls |
201 | | - return _incr_usage_tokens(self, incr_usage) |
| 208 | + with self._lock: |
| 209 | + if isinstance(incr_usage, RunUsage): |
| 210 | + self.requests += incr_usage.requests |
| 211 | + self.tool_calls += incr_usage.tool_calls |
| 212 | + return _incr_usage_tokens(self, incr_usage) |
202 | 213 |
|
203 | 214 | def __add__(self, other: RunUsage | RequestUsage) -> RunUsage: |
204 | 215 | """Add two RunUsages together. |
205 | 216 |
|
206 | 217 | This is provided so it's trivial to sum usage information from multiple runs. |
207 | 218 | """ |
208 | 219 | new_usage = copy(self) |
209 | | - new_usage.incr(other) |
| 220 | + # Note: We can't use await here since __add__ must be synchronous |
| 221 | + # But __add__ creates a new object, so there's no race condition |
| 222 | + # The race condition only happens when modifying the same object concurrently |
| 223 | + # The new instance will get its own lock via the property |
| 224 | + |
| 225 | + if isinstance(other, RunUsage): |
| 226 | + new_usage.requests += other.requests |
| 227 | + new_usage.tool_calls += other.tool_calls |
| 228 | + _incr_usage_tokens(new_usage, other) |
210 | 229 | return new_usage |
211 | 230 |
|
| 231 | + def __getstate__(self) -> dict[str, Any]: |
| 232 | + """Exclude the lock from pickling.""" |
| 233 | + state = self.__dict__.copy() |
| 234 | + # Remove any lock-related attributes since they can't be pickled |
| 235 | + state.pop('_lock', None) |
| 236 | + state.pop('__lock', None) |
| 237 | + return state |
| 238 | + |
| 239 | + def __setstate__(self, state: dict[str, Any]) -> None: |
| 240 | + """Restore state and create a new lock.""" |
| 241 | + self.__dict__.update(state) |
| 242 | + # The lock will be created automatically via the property |
| 243 | + |
212 | 244 |
|
213 | 245 | def _incr_usage_tokens(slf: RunUsage | RequestUsage, incr_usage: RunUsage | RequestUsage) -> None: |
214 | 246 | """Increment the usage in place. |
|
0 commit comments