Skip to content

Commit 369879e

Browse files
authored
Created ClearOptimizerBuffersCallback (#222)
1 parent e5f25cd commit 369879e

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

ldp/alg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .callbacks import (
99
Callback,
1010
ClearContextCallback,
11+
ClearOptimizerBuffersCallback,
1112
ComputeTrajectoryMetricsMixin,
1213
LoggingCallback,
1314
MeanMetricsCallback,
@@ -33,6 +34,7 @@
3334
"BeamSearchRollout",
3435
"Callback",
3536
"ClearContextCallback",
37+
"ClearOptimizerBuffersCallback",
3638
"ComputeTrajectoryMetricsMixin",
3739
"Evaluator",
3840
"EvaluatorConfig",

ldp/alg/callbacks.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import defaultdict
77
from collections.abc import Callable, Collection, Iterable, Sequence
88
from pathlib import Path
9-
from typing import Any, cast
9+
from typing import TYPE_CHECKING, Any, cast
1010

1111
import aiofiles
1212
from aviary.core import (
@@ -27,6 +27,9 @@
2727
except ImportError:
2828
wandb = None # type: ignore[assignment]
2929

30+
if TYPE_CHECKING:
31+
from ldp.alg.optimizer.replay_buffers import ReplayBuffer
32+
3033
logger = logging.getLogger(__name__)
3134

3235

@@ -567,3 +570,14 @@ async def after_env_step(
567570
print("\nObservation:")
568571
pprint(obs, expand_all=True)
569572
print(f"Elapsed time: {elapsed_time:.2f} seconds")
573+
574+
575+
class ClearOptimizerBuffersCallback(Callback):
    """Callback that empties replay buffer(s) following every optimizer update.

    Each buffer passed to the constructor has its ``clear()`` method called,
    in the order the buffers were given, whenever ``after_update`` runs.
    """

    def __init__(self, *buffers: "ReplayBuffer"):
        # Copy the varargs tuple into a list owned by this callback.
        self._buffers = [*buffers]

    async def after_update(self) -> None:
        """Call ``clear()`` on every registered buffer."""
        for replay_buffer in self._buffers:
            replay_buffer.clear()

tests/test_buffers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
@pytest.mark.asyncio
10-
async def test_circular_buffer():
10+
async def test_circular_buffer() -> None:
1111
buf = CircularReplayBuffer()
1212

1313
samples = [{"state": 1, "action": 2, "reward": 3, "t": t} for t in range(5)]
@@ -27,6 +27,9 @@ async def test_circular_buffer():
2727
):
2828
next(buf.batched_iter(batch_size=4))
2929

30+
buf.clear()
31+
assert not buf, "Failed to clear data"
32+
3033

3134
async def _dummy_q_function(*args, **kwargs) -> float: # noqa: ARG001
3235
return 1.0

0 commit comments

Comments
 (0)