-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathmemory.py
More file actions
458 lines (388 loc) · 19.5 KB
/
memory.py
File metadata and controls
458 lines (388 loc) · 19.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import asyncio
import logging
import sys
from collections.abc import Iterable
from collections.abc import Mapping as MappingABC
from typing import (
Any,
Callable,
Counter,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
)
import aiologic
from pydantic import BaseModel
from agentlightning.types import AttemptedRollout, NamedResources, PaginatedResult, ResourcesUpdate, Rollout, Span
from agentlightning.utils.metrics import MetricsBackend
from .base import UNSET, LightningStoreCapabilities, LightningStoreStatistics, Unset, is_finished, is_running
from .collection import InMemoryLightningCollections
from .collection_based import CollectionBasedLightningStore, tracked
T_callable = TypeVar("T_callable", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
def estimate_model_size(obj: Any) -> int:
    """Roughly estimate the in-memory footprint of an object, in bytes.

    The estimate is ``sys.getsizeof(obj)`` plus, recursively, the estimate of
    every child value for the container kinds recognised here: the attribute
    values of a Pydantic ``BaseModel``, the values (not the keys) of a mapping,
    and the items of a ``list``, ``tuple`` or ``set``. Any other object is
    measured with ``sys.getsizeof`` alone.
    """
    children: Iterable[Any]
    if isinstance(obj, BaseModel):
        children = cast(Iterable[Any], obj.__dict__.values())
    elif isinstance(obj, MappingABC):
        children = cast(Mapping[Any, Any], obj).values()
    elif isinstance(obj, (list, tuple, set)):
        children = cast(Iterable[Any], obj)
    else:
        # Leaf object: nothing to descend into.
        children = ()

    total = sys.getsizeof(cast(object, obj))
    for child in children:
        total += estimate_model_size(child)
    return total
def _detect_total_memory_bytes() -> int:
    """Return the total system memory in bytes, as reported by ``psutil``.

    When ``psutil`` cannot be imported, an error is logged and a fixed 8 GiB
    value is returned instead.
    """
    fallback_bytes = 8 * 1024**3
    try:
        import psutil

        detected = psutil.virtual_memory().total
    except ImportError:
        # Fallback to 8GB if memory cannot be detected.
        logger.error("psutil is not installed. Falling back to 8GB of memory in total.")
        return fallback_bytes
    return int(detected)
class InMemoryLightningStore(CollectionBasedLightningStore[InMemoryLightningCollections]):
"""
In-memory implementation of LightningStore using Python data structures.
Async-compatible, and thread-safe when `thread_safe=True`; data is not persistent.
Args:
thread_safe: Whether the store is thread-safe.
eviction_memory_threshold: The threshold for evicting spans in bytes.
By default, it's 70% of the total system memory (RAM) detected.
safe_memory_threshold: The threshold for safe memory usage in bytes.
By default, it's 80% of the eviction threshold.
span_size_estimator: A function to estimate the size of a span in bytes.
By default, it's a simple size estimator that uses sys.getsizeof.
tracker: The metrics tracker to use.
scan_debounce_seconds: The debounce time for the scan for unhealthy rollouts.
Set to 0 to disable debouncing.
"""
def __init__(
    self,
    *,
    thread_safe: bool = False,
    eviction_memory_threshold: float | int | None = None,
    safe_memory_threshold: float | int | None = None,
    span_size_estimator: Callable[[Span], int] | None = None,
    tracker: MetricsBackend | None = None,
    scan_debounce_seconds: float = 10.0,
):
    """Create the in-memory collections and initialise span-memory accounting.

    Threshold arguments are interpreted by `_resolve_memory_threshold`: a float
    is a ratio of the detected memory capacity, an int is an absolute byte count.

    Raises:
        ValueError: If the detected memory capacity is not positive, if a
            threshold fails validation, or if the resolved safe threshold is
            not strictly smaller than the resolved eviction threshold.
    """
    # The lock flavour follows `thread_safe`: "thread" locks when enabled, "asyncio" locks otherwise.
    super().__init__(
        collections=InMemoryLightningCollections(lock_type="thread" if thread_safe else "asyncio", tracker=tracker),
        tracker=tracker,
        scan_debounce_seconds=scan_debounce_seconds,
    )
    self._thread_safe = thread_safe
    # Start time recorded per rollout ID.
    self._start_time_by_rollout: Dict[str, float] = {}
    # Estimated span bytes per rollout; a Counter so missing keys start at 0.
    self._span_bytes_by_rollout: Dict[str, int] = Counter()
    # Running total of the estimated span bytes across all rollouts.
    self._total_span_bytes: int = 0
    # IDs of rollouts whose spans have been evicted.
    self._evicted_rollout_span_sets: Set[str] = set()
    self._memory_capacity_bytes = _detect_total_memory_bytes()
    if self._memory_capacity_bytes <= 0:
        raise ValueError("Detected memory capacity must be positive")
    # Eviction threshold: defaults to 70% of capacity and must resolve to at least 1 byte.
    self._eviction_threshold_bytes = self._resolve_memory_threshold(
        eviction_memory_threshold,
        default_ratio=0.7,
        capacity_bytes=self._memory_capacity_bytes,
        name="eviction_memory_threshold",
        minimum=1,
    )
    # Safe threshold: defaults to 80% of the eviction threshold (as an absolute byte count).
    if safe_memory_threshold is None:
        safe_memory_threshold = max(int(self._eviction_threshold_bytes * 0.8), 0)
    # The value passed here is never None at this point, so `default_ratio` below is not consulted.
    self._safe_threshold_bytes = self._resolve_memory_threshold(
        safe_memory_threshold,
        default_ratio=self._eviction_threshold_bytes / self._memory_capacity_bytes,
        capacity_bytes=self._memory_capacity_bytes,
        name="safe_memory_threshold",
        minimum=0,
    )
    # The safe threshold must sit strictly below the eviction threshold.
    if not (0 <= self._safe_threshold_bytes < self._eviction_threshold_bytes):
        raise ValueError("safe_memory_threshold must be smaller than eviction_memory_threshold")
    # Optional user-supplied estimator; None means `estimate_model_size` is used.
    self._custom_span_size_estimator = span_size_estimator
    # Completion tracking for wait_for_rollouts (cross-loop safe)
    self._completion_events: Dict[str, aiologic.Event] = {}
    # Running rollouts cache, including preparing and running rollouts
    self._running_rollout_ids: Set[str] = set()
    # Caches the latest resources ID.
    self._latest_resources_id: Union[str, None, Unset] = UNSET
@property
def capabilities(self) -> LightningStoreCapabilities:
    """Describe what this store supports.

    Thread safety mirrors the `thread_safe` flag given at construction; the
    store always reports itself as async-safe, and as supporting neither
    zero-copy nor OTLP traces.
    """
    supported = LightningStoreCapabilities(
        thread_safe=self._thread_safe,
        async_safe=True,
        zero_copy=False,
        otlp_traces=False,
    )
    return supported
async def statistics(self) -> LightningStoreStatistics:
    """Return the parent store's statistics extended with the memory-accounting counters."""
    inherited = await super().statistics()
    return {
        **inherited,
        "total_span_bytes": self._total_span_bytes,
        "eviction_threshold_bytes": self._eviction_threshold_bytes,
        "safe_threshold_bytes": self._safe_threshold_bytes,
        "memory_capacity_bytes": self._memory_capacity_bytes,
    }
@tracked("wait_for_rollout")
async def wait_for_rollout(self, rollout_id: str, timeout: Optional[float] = None) -> Optional[Rollout]:
    """Wait for one rollout to finish and return it.

    Args:
        rollout_id: ID of the rollout to wait for.
        timeout: Maximum number of seconds to wait. ``None`` waits indefinitely;
            a value of zero or less only checks the current state.

    Returns:
        The rollout if it is finished, either already or by the time the wait
        ends. ``None`` if the timeout is non-positive, if no completion event is
        registered for the rollout, if the wait times out, or if the rollout is
        not found in a finished state after the event fires.
    """
    async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
        current = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
    if current and is_finished(current):
        return current

    if timeout is not None and timeout <= 0:
        return None

    # Without a registered completion event there is nothing to wait on.
    if rollout_id not in self._completion_events:
        return None
    evt = self._completion_events[rollout_id]

    # evt.wait() returns True if the event was set, False if the timeout elapsed.
    if timeout is None:
        # Wait indefinitely by polling with finite timeouts (every 10 seconds),
        # which allows the worker threads to exit cleanly on shutdown.
        signalled = False
        while not signalled:
            signalled = await asyncio.to_thread(evt.wait, 10.0)
    else:
        # Wait with the specified timeout.
        signalled = await asyncio.to_thread(evt.wait, timeout)

    if not signalled:
        return None

    # The event fired: re-read the rollout and confirm it is finished.
    async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["rollouts"]) as collections:
        current = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
    if current and is_finished(current):
        return current
    return None
@tracked("add_resources_inmemory")
async def add_resources(self, resources: NamedResources) -> ResourcesUpdate:
    """Add resources through the parent store and cache the resulting ID as the latest one."""
    update = await super().add_resources(resources)
    # The cached ID is written while holding the read-write lock on "resources".
    async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
        self._latest_resources_id = update.resources_id
    return update
@tracked("update_resources_inmemory")
async def update_resources(self, resources_id: str, resources: NamedResources) -> ResourcesUpdate:
    """Update resources through the parent store and cache the resulting ID as the latest one."""
    update = await super().update_resources(resources_id, resources)
    # The cached ID is written while holding the read-write lock on "resources".
    async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["resources"]):
        self._latest_resources_id = update.resources_id
    return update
@tracked("_post_update_rollout_inmemory")
async def _post_update_rollout(
    self, rollouts: Sequence[Tuple[Rollout, Sequence[str]]], skip_enqueue: bool = False
) -> None:
    """Refresh the in-memory rollout bookkeeping after rollouts are updated.

    After the parent hook runs, each rollout is added to or removed from the
    running-ID cache, gets a completion event (set once the rollout is
    finished), and has its start time recorded the first time it is seen.
    """
    await super()._post_update_rollout(rollouts, skip_enqueue=skip_enqueue)
    async with self.collections.atomic(mode="rw", snapshot=self._read_snapshot, labels=["rollouts"]):
        for rollout, _ in rollouts:
            rid = rollout.rollout_id

            if is_running(rollout):
                self._running_rollout_ids.add(rid)
            else:
                self._running_rollout_ids.discard(rid)

            # Every rollout gets an event; it is only signalled once finished.
            # Rollout status can never transition from finished to running (unlike attempt)
            # so the completion event never needs clearing, even in case of retrying.
            finished = is_finished(rollout)
            completion = self._completion_events.setdefault(rid, aiologic.Event())
            if finished:
                completion.set()

            # Keep the first start time observed for the rollout.
            self._start_time_by_rollout.setdefault(rid, rollout.start_time)
@tracked("_unlocked_query_rollouts_by_rollout_ids")
async def _unlocked_query_rollouts_by_rollout_ids(
    self, collections: InMemoryLightningCollections, rollout_ids: Sequence[str]
) -> List[Rollout]:
    """Fetch rollouts one by one with exact-match lookups, skipping IDs that are not found.

    Always use exact. This is faster than within filter for in-memory store.
    """
    if not rollout_ids:
        return []
    found: List[Rollout] = []
    for rid in rollout_ids:
        match = await collections.rollouts.get({"rollout_id": {"exact": rid}})
        if match is not None:
            found.append(match)
    return found
@tracked("_unlocked_get_running_rollouts")
async def _unlocked_get_running_rollouts(self, collections: InMemoryLightningCollections) -> List[AttemptedRollout]:
    """Accelerated version of `_unlocked_get_running_rollouts` for in-memory store. Used for healthcheck.

    Only the rollouts in the cached running-ID set are looked up; each one is
    paired with its attempt of highest `sequence_id`. A rollout without any
    attempt is logged as an error and left out of the result.
    """
    # NOTE(review): the `collections` argument is shadowed by the handle acquired
    # here, so the caller-supplied value is never used — confirm this is intended.
    async with self.collections.atomic(
        mode="r", snapshot=self._read_snapshot, labels=["rollouts", "attempts"]
    ) as collections:
        cached_ids = list(self._running_rollout_ids)
        rollouts = await self._unlocked_query_rollouts_by_rollout_ids(collections, cached_ids)
        attempted: List[AttemptedRollout] = []
        for rollout in rollouts:
            latest_attempt = await collections.attempts.get(
                filter={"rollout_id": {"exact": rollout.rollout_id}},
                sort={"name": "sequence_id", "order": "desc"},
            )
            if latest_attempt:
                attempted.append(AttemptedRollout(**rollout.model_dump(), attempt=latest_attempt))
            else:
                # The rollout is running but has no attempts, this should not happen
                logger.error(f"Rollout {rollout.rollout_id} is running but has no attempts")
        return attempted
@tracked("query_spans_inmemory")  # Since this method calls super, we need to track it separately
async def query_spans(
    self,
    rollout_id: str,
    attempt_id: str | Literal["latest"] | None = None,
    **kwargs: Any,
) -> PaginatedResult[Span]:
    """Query spans through the parent store, refusing rollouts whose spans were evicted.

    Raises:
        RuntimeError: If the spans of `rollout_id` have been evicted.
    """
    if rollout_id not in self._evicted_rollout_span_sets:
        return await super().query_spans(rollout_id, attempt_id, **kwargs)
    raise RuntimeError(f"Spans for rollout {rollout_id} have been evicted")
@tracked("_post_add_spans")
async def _post_add_spans(self, spans: Sequence[Span], rollout_id: str, attempt_id: str) -> None:
    """In-memory store needs to maintain the span data in memory, and evict spans when memory is low.

    Runs the parent hook first. Then, holding a read-write lock on the
    "rollouts" and "spans" collections, adds the estimated size of every new
    span to the byte counters and runs the eviction check.
    """
    await super()._post_add_spans(spans, rollout_id, attempt_id)
    async with self.collections.atomic(
        mode="rw", snapshot=self._read_snapshot, labels=["rollouts", "spans"]
    ) as collections:
        # Account for the whole batch before deciding whether to evict.
        for span in spans:
            await self._account_span_size(span)
        # Evicts only when the running total exceeds the eviction threshold.
        await self._maybe_evict_spans(collections)
@tracked("_get_latest_resources_inmemory")
async def _get_latest_resources(self) -> Optional[ResourcesUpdate]:
    """Return the latest resources, using the cached latest resources ID when one is set.

    While the cache is still unset the lookup is delegated to the parent store.
    A cached ID of ``None`` yields ``None``.
    """
    latest_id = self._latest_resources_id
    if isinstance(latest_id, Unset):
        # Nothing cached yet: let the parent store resolve it.
        return await super()._get_latest_resources()
    if latest_id is None:
        return None
    async with self.collections.atomic(mode="r", snapshot=self._read_snapshot, labels=["resources"]) as collections:
        return await collections.resources.get(filter={"resources_id": {"exact": latest_id}})
@staticmethod
def _resolve_memory_threshold(
    value: float | int | None,
    *,
    default_ratio: float,
    capacity_bytes: int,
    name: str,
    minimum: int,
) -> int:
    """Turn a user-supplied memory threshold into a byte count.

    Args:
        value: ``None`` to use `default_ratio`, a float ratio of
            `capacity_bytes`, or any other value taken as a byte count.
        default_ratio: Ratio of `capacity_bytes` used when `value` is ``None``.
        capacity_bytes: Memory capacity that ratios are applied to.
        name: Parameter name used in error messages.
        minimum: Smallest acceptable result. A float ratio of exactly 0 is
            only accepted when this is 0.

    Returns:
        The resolved threshold in bytes.

    Raises:
        ValueError: If a float ratio is out of range, a byte count is negative,
            or the resolved value is below `minimum`.
    """
    if value is None:
        threshold = int(capacity_bytes * default_ratio)
    elif isinstance(value, float):
        zero_allowed = minimum == 0
        above_floor = value >= 0 if zero_allowed else value > 0
        if not (above_floor and value <= 1):
            if zero_allowed:
                raise ValueError(f"{name} ratio must be between 0 and 1 inclusive")
            raise ValueError(f"{name} ratio must be greater than 0 and at most 1")
        threshold = int(capacity_bytes * value)
    else:
        if value < 0:
            raise ValueError(f"{name} must be non-negative")
        threshold = value

    if threshold < minimum:
        raise ValueError(f"{name} must be at least {minimum} bytes")
    return threshold
@tracked("_account_span_size")
async def _account_span_size(self, span: Span) -> int:
    """Estimate one span's size and add it to the per-rollout and global byte counters.

    The custom estimator is used when one was supplied, with its result
    converted to ``int`` and clamped to be non-negative; otherwise
    `estimate_model_size` is used.

    Returns:
        The size in bytes that was added for this span.
    """
    estimator = self._custom_span_size_estimator
    if estimator is None:
        span_bytes = estimate_model_size(span)
    else:
        span_bytes = max(int(estimator(span)), 0)

    self._total_span_bytes += span_bytes
    self._span_bytes_by_rollout[span.rollout_id] += span_bytes
    return span_bytes
@tracked("_maybe_evict_spans")
async def _maybe_evict_spans(self, collections: InMemoryLightningCollections) -> None:
    """Evict spans, oldest rollout first, once the eviction threshold is exceeded.

    Nothing happens while the tracked total is at or below the eviction
    threshold. Otherwise rollouts are visited in ascending start-time order
    and their spans are evicted until the total drops to the safe threshold
    or below.
    """
    if self._total_span_bytes <= self._eviction_threshold_bytes:
        return

    logger.info(
        f"Total span bytes: {self._total_span_bytes}, eviction threshold: {self._eviction_threshold_bytes}, "
        f"safe threshold: {self._safe_threshold_bytes}. Evicting spans..."
    )

    # Sorting the (start_time, rollout_id) pairs puts the oldest rollouts first.
    candidates: List[tuple[float, str]] = sorted(
        (started_at, rid) for rid, started_at in self._start_time_by_rollout.items()
    )
    logger.info(f"Evicting spans for {len(candidates)} rollouts to free up memory...")

    memory_consumed_before = self._total_span_bytes
    for _, rollout_id in candidates:
        if self._total_span_bytes <= self._safe_threshold_bytes:
            break
        logger.debug(f"Evicting spans for rollout {rollout_id} to free up memory...")
        await self._evict_spans_for_rollout(collections, rollout_id)

    logger.info(f"Freed up {memory_consumed_before - self._total_span_bytes} bytes of memory")
@tracked("_evict_spans_for_rollout")
async def _evict_spans_for_rollout(self, collections: InMemoryLightningCollections, rollout_id: str) -> None:
    """Evict one rollout's spans from the collections and release its tracked bytes.

    The rollout is recorded as evicted only when it had a positive tracked byte count.
    """
    await collections.evict_spans_for_rollout(rollout_id)
    freed = self._span_bytes_by_rollout.pop(rollout_id, 0)
    if freed <= 0:
        # No bytes were tracked for this rollout, so there is nothing to record.
        return
    # There is something removed for real
    self._total_span_bytes = max(self._total_span_bytes - freed, 0)
    self._evicted_rollout_span_sets.add(rollout_id)
@tracked("cleanup_finished_rollouts")
async def cleanup_finished_rollouts(self, rollout_ids: Optional[Sequence[str]] = None) -> int:
    """Remove all data associated with finished rollouts to free memory.

    This should be called after training data has been extracted from completed
    rollouts (e.g., after get_train_data_batch or get_test_metrics). It removes
    rollouts, their attempts, spans, and associated tracking metadata from all
    in-memory data structures.

    Only rollouts that exist and are finished are touched. IDs that are unknown
    or refer to unfinished rollouts are skipped entirely, so their completion
    events, running-set membership and byte accounting stay intact.

    Args:
        rollout_ids: Optional list of rollout IDs to clean up. If None, all
            finished rollouts will be cleaned up.

    Returns:
        The number of rollouts cleaned up.
    """
    cleaned_count = 0
    async with self.collections.atomic(
        mode="rw",
        snapshot=self._read_snapshot,
        labels=["rollouts", "attempts", "spans", "span_sequence_ids"],
    ) as collections:
        # Determine which rollouts to clean up
        if rollout_ids is None:
            all_rollouts = await collections.rollouts.query()
            target_ids = [r.rollout_id for r in all_rollouts.items if is_finished(r)]
        else:
            target_ids = list(rollout_ids)

        for rollout_id in target_ids:
            rollout = await collections.rollouts.get({"rollout_id": {"exact": rollout_id}})
            if rollout is None or not is_finished(rollout):
                continue

            # Remove spans for this rollout
            await collections.evict_spans_for_rollout(rollout_id)

            # Remove attempts for this rollout
            attempts_result = await collections.attempts.query(filter={"rollout_id": {"exact": rollout_id}})
            if attempts_result.items:
                await collections.attempts.delete(attempts_result.items)

            # Remove the rollout itself
            await collections.rollouts.delete([rollout])

            # Remove span sequence ID tracking
            # NOTE(review): assumes `pop` tolerates a rollout without an entry — confirm.
            await collections.span_sequence_ids.pop(rollout_id)

            # Drop the auxiliary tracking state only for rollouts that were really
            # removed, and while still holding the lock like the other mutators do.
            self._completion_events.pop(rollout_id, None)
            self._start_time_by_rollout.pop(rollout_id, None)
            self._running_rollout_ids.discard(rollout_id)
            self._evicted_rollout_span_sets.discard(rollout_id)

            # Give the rollout's bytes back to the global counter; otherwise the
            # total stays inflated forever and triggers needless evictions.
            removed_bytes = self._span_bytes_by_rollout.pop(rollout_id, 0)
            self._total_span_bytes = max(self._total_span_bytes - removed_bytes, 0)

            cleaned_count += 1

    if cleaned_count > 0:
        logger.info(
            "Cleaned up %d finished rollouts. Completion events: %d, "
            "start time entries: %d, span byte entries: %d",
            cleaned_count,
            len(self._completion_events),
            len(self._start_time_by_rollout),
            len(self._span_bytes_by_rollout),
        )
    return cleaned_count