Skip to content

Commit 1ed0d1e

Browse files
authored
[BugFix] Fix failing compiled storage access (#3547)
1 parent 491ed0f commit 1ed0d1e

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

test/test_loggers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
else:
3636
TORCHVISION_VERSION = version.parse("0.0.1")
3737

38-
_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec
38+
_has_mp4 = (
39+
_has_tv and version.parse("0.20") <= TORCHVISION_VERSION < version.parse("0.22")
40+
) or _has_torchcodec
3941

4042
if _has_tb:
4143
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

torchrl/data/replay_buffers/storages.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def _storage_keys(self, value):
770770

771771
@property
772772
def _len(self):
773-
_len_value = self.__dict__.get("_len_value", None)
773+
_len_value = getattr(self, "_len_value", None)
774774
if not self._compilable:
775775
if _len_value is None:
776776
_len_value = self._len_value = mp.Value("i", 0)
@@ -783,7 +783,7 @@ def _len(self):
783783
@_len.setter
784784
def _len(self, value):
785785
if not is_compiling() and not self._compilable:
786-
_len_value = self.__dict__.get("_len_value", None)
786+
_len_value = getattr(self, "_len_value", None)
787787
if _len_value is None:
788788
_len_value = self._len_value = mp.Value("i", 0)
789789
_len_value.value = value
@@ -793,7 +793,7 @@ def _len(self, value):
793793
@property
794794
def _total_shape(self):
795795
# Total shape, irrespective of how full the storage is
796-
_total_shape = self.__dict__.get("_total_shape_value", None)
796+
_total_shape = getattr(self, "_total_shape_value", None)
797797
if _total_shape is None and self.initialized:
798798
if is_tensor_collection(self._storage):
799799
_total_shape = self._storage.shape[: self.ndim]
@@ -2440,7 +2440,7 @@ def __init__(
24402440

24412441
@property
24422442
def _len(self):
2443-
_len_value = self.__dict__.get("_len_value", None)
2443+
_len_value = getattr(self, "_len_value", None)
24442444
if not self._compilable:
24452445
if _len_value is None:
24462446
_len_value = self._len_value = mp.Value("i", 0)
@@ -2452,7 +2452,7 @@ def _len(self):
24522452
@_len.setter
24532453
def _len(self, value):
24542454
if not is_compiling() and not self._compilable:
2455-
_len_value = self.__dict__.get("_len_value", None)
2455+
_len_value = getattr(self, "_len_value", None)
24562456
if _len_value is None:
24572457
_len_value = self._len_value = mp.Value("i", 0)
24582458
_len_value.value = value

torchrl/data/replay_buffers/writers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _mark_update_entities(self, index: torch.Tensor) -> None:
236236

237237
@property
238238
def _cursor(self):
239-
_cursor_value = self.__dict__.get("_cursor_value", None)
239+
_cursor_value = getattr(self, "_cursor_value", None)
240240
if not self._compilable:
241241
if _cursor_value is None:
242242
_cursor_value = self._cursor_value = mp.Value("i", 0)
@@ -249,7 +249,7 @@ def _cursor(self):
249249
@_cursor.setter
250250
def _cursor(self, value):
251251
if not self._compilable:
252-
_cursor_value = self.__dict__.get("_cursor_value", None)
252+
_cursor_value = getattr(self, "_cursor_value", None)
253253
if _cursor_value is None:
254254
_cursor_value = self._cursor_value = mp.Value("i", 0)
255255
_cursor_value.value = value
@@ -258,7 +258,7 @@ def _cursor(self, value):
258258

259259
@property
260260
def _write_count(self):
261-
_write_count = self.__dict__.get("_write_count_value", None)
261+
_write_count = getattr(self, "_write_count_value", None)
262262
if not self._compilable:
263263
if _write_count is None:
264264
_write_count = self._write_count_value = mp.Value("i", 0)
@@ -271,7 +271,7 @@ def _write_count(self):
271271
@_write_count.setter
272272
def _write_count(self, value):
273273
if not self._compilable:
274-
_write_count = self.__dict__.get("_write_count_value", None)
274+
_write_count = getattr(self, "_write_count_value", None)
275275
if _write_count is None:
276276
_write_count = self._write_count_value = mp.Value("i", 0)
277277
_write_count.value = value
@@ -531,14 +531,14 @@ def get_insert_index(self, data: Any) -> int:
531531

532532
@property
533533
def _write_count(self):
534-
_write_count = self.__dict__.get("_write_count_value", None)
534+
_write_count = getattr(self, "_write_count_value", None)
535535
if _write_count is None:
536536
_write_count = self._write_count_value = mp.Value("i", 0)
537537
return _write_count.value
538538

539539
@_write_count.setter
540540
def _write_count(self, value):
541-
_write_count = self.__dict__.get("_write_count_value", None)
541+
_write_count = getattr(self, "_write_count_value", None)
542542
if _write_count is None:
543543
_write_count = self._write_count_value = mp.Value("i", 0)
544544
_write_count.value = value

0 commit comments

Comments
 (0)