Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1f6f327
Fixed MultiSyncCollector set_seed and split_trajs issue
ParamThakkar123 Jan 19, 2026
e2aaf6b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 20, 2026
40642d5
Revert "Fixed MultiSyncCollector set_seed and split_trajs issue"
ParamThakkar123 Jan 20, 2026
efdc89c
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 21, 2026
628f44b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 23, 2026
a476a77
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 24, 2026
0f565c5
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 25, 2026
7fb086b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 26, 2026
ff72793
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 28, 2026
69001ed
Added Support for index_select in TensorSpec
ParamThakkar123 Jan 28, 2026
4ab13be
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 29, 2026
2e8face
rebase
ParamThakkar123 Jan 29, 2026
56e1529
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 31, 2026
ba6a19f
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 4, 2026
8be545b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 5, 2026
54abe29
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 8, 2026
78dd00a
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 12, 2026
94fe080
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 13, 2026
524a4de
Added Lazy implementation of priority updates for replaybuffer prototype
ParamThakkar123 Feb 13, 2026
5cc4e78
Updates
ParamThakkar123 Feb 13, 2026
592aa29
[PrioritizedSampler] Clean up lazy mark_update implementation
vmoens Feb 14, 2026
366bf83
Merge branch 'main' of https://github.com/pytorch/rl into add/lazy
ParamThakkar123 Feb 17, 2026
62cde31
Merge branch 'add/lazy' of https://github.com/ParamThakkar123/rl into…
ParamThakkar123 Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2077,6 +2077,84 @@ def test_shared_storage_prioritized_sampler():
assert rb1._sampler._sum_tree.query(0, 70) == 50


def test_prioritized_sampler_mark_update_is_lazy():
    """Extending a prioritized buffer defers priority writes until the tree is read."""

    def _fresh_buffer():
        # A new buffer per scenario so the priority trees start empty.
        return ReplayBuffer(
            storage=LazyTensorStorage(10),
            sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
            batch_size=2,
        )

    buffer = _fresh_buffer()
    buffer.extend(torch.arange(4))
    # extend() only queues default-priority updates; nothing is written yet.
    assert buffer._sampler._sum_tree.query(0, 4) == pytest.approx(0.0)
    # sample() flushes the queued updates into the sum-tree.
    buffer.sample()
    assert buffer._sampler._sum_tree.query(0, 4) == pytest.approx(4.0)

    buffer = _fresh_buffer()
    indices = buffer.extend(torch.arange(4))
    # Explicit update_priority flushes the queue, then overwrites with 2.0 each.
    buffer.update_priority(indices, 2)
    assert buffer._sampler._sum_tree.query(0, 4) == pytest.approx(8.0)


def test_prioritized_sampler_lazy_multiple_extends():
    """Multiple extends accumulate pending updates; all are flushed on sample."""
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
        batch_size=2,
    )
    for offset in (0, 3):
        buffer.extend(torch.arange(3) + offset)

    # The tree is still untouched after both extends.
    assert buffer._sampler._sum_tree.query(0, 6) == pytest.approx(0.0)

    buffer.sample()
    # All 6 indices should now carry the default priority (~1.0 each).
    assert buffer._sampler._sum_tree.query(0, 6) == pytest.approx(6.0, abs=1e-4)


def test_prioritized_sampler_lazy_empty_clears_pending():
    """Emptying the buffer must also discard any queued priority updates."""
    prioritized = PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=prioritized,
        batch_size=2,
    )
    buffer.extend(torch.arange(4))
    # extend() queues at least one pending default-priority update.
    assert len(prioritized._pending_updates) > 0

    buffer.empty()
    # empty() drops the queue and leaves the tree with zero total mass.
    assert len(prioritized._pending_updates) == 0
    assert prioritized._sum_tree.query(0, 10) == pytest.approx(0.0)


def test_prioritized_sampler_lazy_state_dict_roundtrip():
    """state_dict flushes pending updates; load_state_dict clears them."""
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(10),
        sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
        batch_size=2,
    )
    buffer.extend(torch.arange(4))

    # Serializing must first materialize the queued priorities.
    snapshot = buffer._sampler.state_dict()
    assert len(buffer._sampler._pending_updates) == 0

    # Restoring into a fresh sampler yields the flushed tree and no queue.
    restored = PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0)
    restored.load_state_dict(snapshot)
    assert len(restored._pending_updates) == 0
    assert restored._sum_tree.query(0, 4) == pytest.approx(4.0, abs=1e-4)


class TestTransforms:
def test_append_transform(self):
rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), batch_size=1)
Expand Down Expand Up @@ -3360,6 +3438,7 @@ def test_prb_ndim(self):
)
data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])
idx = rb.extend(data)
rb.sample()
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all()
rb.update_priority(idx, 2)
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all()
Expand All @@ -3378,6 +3457,7 @@ def test_prb_ndim(self):
)
data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])
idx = rb.extend(data)
rb.sample()
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all()
rb.update_priority(idx, 2)
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all()
Expand Down Expand Up @@ -3423,6 +3503,7 @@ def test_prb_ndim(self):
{"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
)
idx = rb.extend(data)
rb.sample()
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all()
rb.update_priority(idx, 2)
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all()
Expand All @@ -3445,6 +3526,7 @@ def test_prb_ndim(self):
{"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
)
idx = rb.extend(data)
rb.sample()
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all()
rb.update_priority(idx, 2)
assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all()
Expand Down
181 changes: 120 additions & 61 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def _init(self) -> None:
f"dtype {self.dtype} not supported by PrioritizedSampler"
)
self._max_priority = None
self._pending_updates = []

def _empty(self) -> None:
self._init()
Expand Down Expand Up @@ -543,7 +544,89 @@ def default_priority(self) -> float:
mp = 1
return (mp + self._eps) ** self._alpha

def _normalize_index(
    self,
    index: int | torch.Tensor | tuple[torch.Tensor, ...],
    *,
    storage: TensorStorage | None = None,
) -> torch.Tensor:
    """Coerce ``index`` into a flat 1-d long tensor on CPU.

    Tuples of tensors (one per storage dimension) are stacked into a
    trailing-dim index; multi-dimensional indices are raveled against
    ``storage.shape``; scalar indices become a length-1 tensor.
    """
    if isinstance(index, tuple):
        # One tensor per dimension -> (..., ndim) coordinate tensor.
        index = torch.stack(index, -1)
    index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
    if _is_int(index):
        return index.reshape(1)
    if index.ndim <= 1:
        return index.reshape(-1)
    # Multi-dimensional coordinates require the storage shape to flatten.
    if storage is None:
        raise RuntimeError(
            "storage should be provided to Sampler.update_priority when the storage has more "
            "than one dimension."
        )
    try:
        shape = storage.shape
    except AttributeError:
        raise AttributeError(
            "Could not retrieve the storage shape. If your storage is not a TensorStorage subclass "
            "or its shape isn't accessible via the shape attribute, submit an issue on GitHub."
        )
    flat = torch.as_tensor(np.ravel_multi_index(index.unbind(-1), shape))
    return flat.reshape(-1)

def _flush_pending_updates(self) -> None:
    """Apply every queued ``(index, priority)`` pair to the priority trees.

    The queue is swapped out before iterating so any re-entrant call made
    while flushing observes an already-empty queue.
    """
    if not self._pending_updates:
        return
    queued, self._pending_updates = self._pending_updates, []
    for idx, prio in queued:
        self._update_priority_tree(idx, prio)

@torch.no_grad()
def _update_priority_tree(
    self,
    index: torch.Tensor,
    priority: torch.Tensor,
) -> None:
    """Write ``priority`` into the sum/min trees at the given flat 1-d index.

    Both ``index`` and ``priority`` must already be CPU tensors; ``index``
    must be a 1-d long tensor (as produced by :meth:`_normalize_index`)
    with negative entries filtered out by the caller.
    """
    # Align priority with index when their shapes disagree.
    if priority.numel() > 1 and priority.shape != index.shape:
        try:
            priority = priority.reshape(index.shape[:1])
        except Exception as err:
            raise RuntimeError(
                "priority should be a number or an iterable of the same "
                f"length as index. Got priority of shape {priority.shape} and index "
                f"{index.shape}."
            ) from err
    elif priority.numel() <= 1:
        priority = priority.squeeze()

    # Track the running maximum of the raw (pre-exponentiation) priority.
    max_p, max_p_idx = priority.max(dim=0)
    cur_max, cur_max_idx = self._max_priority
    if cur_max is None or max_p > cur_max:
        cur_max = max_p
        cur_max_idx = index[max_p_idx] if index.ndim else index
        self._max_priority = (cur_max, cur_max_idx)
    scaled = torch.pow(priority + self._eps, self._alpha)
    self._sum_tree[index] = scaled
    self._min_tree[index] = scaled
    # If this write may have overwritten the tracked max entry, recompute
    # the maximum from the whole tree.
    if (
        self._max_priority_within_buffer
        and cur_max_idx is not None
        and (index == cur_max_idx).any()
    ):
        maxval, maxidx = torch.tensor(
            [self._sum_tree[i] for i in range(self._max_capacity)]
        ).max(0)
        self._max_priority = (maxval, maxidx)

def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
self._flush_pending_updates()
if len(storage) == 0:
raise RuntimeError(_EMPTY_STORAGE_ERROR)
p_sum = self._sum_tree.query(0, len(storage))
Expand Down Expand Up @@ -621,76 +704,48 @@ def update_priority(
``index.ndim > 2``.

"""
self._flush_pending_updates()
priority = torch.as_tensor(priority, device=torch.device("cpu")).detach()
index = torch.as_tensor(index, dtype=torch.long, device=torch.device("cpu"))
# we need to reshape priority if it has more than one element or if it has
# a different shape than index
if priority.numel() > 1 and priority.shape != index.shape:
try:
priority = priority.reshape(index.shape[:1])
except Exception as err:
raise RuntimeError(
"priority should be a number or an iterable of the same "
f"length as index. Got priority of shape {priority.shape} and index "
f"{index.shape}."
) from err
elif priority.numel() <= 1:
priority = priority.squeeze()

index = self._normalize_index(index, storage=storage)
# MaxValueWriter will set -1 for items in the data that we don't want
# to update. We therefore have to keep only the non-negative indices.
if _is_int(index):
if index == -1:
return
else:
if index.ndim > 1:
if storage is None:
raise RuntimeError(
"storage should be provided to Sampler.update_priority when the storage has more "
"than one dimension."
)
try:
shape = storage.shape
except AttributeError:
raise AttributeError(
"Could not retrieve the storage shape. If your storage is not a TensorStorage subclass "
"or its shape isn't accessible via the shape attribute, submit an issue on GitHub."
)
index = torch.as_tensor(np.ravel_multi_index(index.unbind(-1), shape))
valid_index = index >= 0
if not valid_index.any():
return
if not valid_index.all():
index = index[valid_index]
if priority.ndim:
priority = priority[valid_index]

max_p, max_p_idx = priority.max(dim=0)
cur_max_priority, cur_max_priority_index = self._max_priority
if cur_max_priority is None or max_p > cur_max_priority:
cur_max_priority, cur_max_priority_index = self._max_priority = (
max_p,
index[max_p_idx] if index.ndim else index,
)
priority = torch.pow(priority + self._eps, self._alpha)
self._sum_tree[index] = priority
self._min_tree[index] = priority
if (
self._max_priority_within_buffer
and cur_max_priority_index is not None
and (index == cur_max_priority_index).any()
):
maxval, maxidx = torch.tensor(
[self._sum_tree[i] for i in range(self._max_capacity)]
).max(0)
self._max_priority = (maxval, maxidx)
valid_index = index >= 0
if not valid_index.any():
return
if not valid_index.all():
index = index[valid_index]
if priority.ndim:
priority = priority[valid_index]
self._update_priority_tree(index, priority)

def mark_update(
self, index: int | torch.Tensor, *, storage: Storage | None = None
) -> None:
self.update_priority(index, self.default_priority, storage=storage)
"""Marks the given indices for a default-priority update.

The update is **lazy**: the priority tree is not written to immediately.
Instead the (index, default_priority) pair is appended to an internal
pending-updates list and flushed the next time the tree is read
(e.g. on :meth:`sample`, :meth:`update_priority`, :meth:`state_dict`,
or :meth:`dumps`).

If :meth:`update_priority` is called for the same indices before the
flush, the pending defaults are applied first and then overwritten by
the explicit priorities.
"""
index = self._normalize_index(index, storage=storage)
valid_index = index >= 0
if not valid_index.any():
return
if not valid_index.all():
index = index[valid_index]
priority = torch.as_tensor(
self.default_priority, device=torch.device("cpu")
).detach()
self._pending_updates.append((index.clone(), priority))

def state_dict(self) -> dict[str, Any]:
self._flush_pending_updates()
return {
"_alpha": self._alpha,
"_beta": self._beta,
Expand All @@ -707,13 +762,15 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self._max_priority = state_dict["_max_priority"]
self._sum_tree = state_dict.pop("_sum_tree")
self._min_tree = state_dict.pop("_min_tree")
self._pending_updates = []

@implement_for("torch", None, "2.5.0")
def dumps(self, path):
    # Guard overload: the real implementation is registered for torch>=2.5
    # below, so older torch versions fail loudly instead of mis-serializing.
    raise NotImplementedError("This method is not implemented for Torch < 2.5.0")

@implement_for("torch", "2.5.0", None)
def dumps(self, path): # noqa: F811
self._flush_pending_updates()
path = Path(path).absolute()
path.mkdir(exist_ok=True)
try:
Expand Down Expand Up @@ -797,6 +854,7 @@ def loads(self, path): # noqa: F811
self._sum_tree[i] = elt
for i, elt in enumerate(mm_mt.tolist()):
self._min_tree[i] = elt
self._pending_updates = []


class SliceSampler(Sampler):
Expand Down Expand Up @@ -2229,6 +2287,7 @@ def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
return preceding_stop_idx

def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict]:
self._flush_pending_updates()
# Sample `batch_size` indices representing the start of a slice.
# The sampling is based on a weight vector.
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
Expand Down
1 change: 0 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6828,7 +6828,6 @@ def _stack_specs(list_of_spec, dim=0, out=None):
else:
raise NotImplementedError


@TensorSpec.implements_for_spec(torch.index_select)
@Composite.implements_for_spec(torch.index_select)
def _index_select_spec(input: TensorSpec, dim: int, index: torch.Tensor) -> TensorSpec:
Expand Down