diff --git a/test/test_loggers.py b/test/test_loggers.py index e9a82c6fcff..65bf781dbd5 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -35,7 +35,9 @@ else: TORCHVISION_VERSION = version.parse("0.0.1") -_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec +_has_mp4 = ( + _has_tv and version.parse("0.20") <= TORCHVISION_VERSION < version.parse("0.22") +) or _has_torchcodec if _has_tb: from tensorboard.backend.event_processing.event_accumulator import EventAccumulator diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 471488d6eb3..7d263d3a428 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -770,7 +770,7 @@ def _storage_keys(self, value): @property def _len(self): - _len_value = self.__dict__.get("_len_value", None) + _len_value = getattr(self, "_len_value", None) if not self._compilable: if _len_value is None: _len_value = self._len_value = mp.Value("i", 0) @@ -783,7 +783,7 @@ def _len(self): @_len.setter def _len(self, value): if not is_compiling() and not self._compilable: - _len_value = self.__dict__.get("_len_value", None) + _len_value = getattr(self, "_len_value", None) if _len_value is None: _len_value = self._len_value = mp.Value("i", 0) _len_value.value = value @@ -793,7 +793,7 @@ def _len(self, value): @property def _total_shape(self): # Total shape, irrespective of how full the storage is - _total_shape = self.__dict__.get("_total_shape_value", None) + _total_shape = getattr(self, "_total_shape_value", None) if _total_shape is None and self.initialized: if is_tensor_collection(self._storage): _total_shape = self._storage.shape[: self.ndim] @@ -2440,7 +2440,7 @@ def __init__( @property def _len(self): - _len_value = self.__dict__.get("_len_value", None) + _len_value = getattr(self, "_len_value", None) if not self._compilable: if _len_value is None: _len_value = self._len_value = mp.Value("i", 0) @@ -2452,7 +2452,7 @@ def _len(self): @_len.setter def _len(self, value): if not is_compiling() and not self._compilable: - _len_value = self.__dict__.get("_len_value", None) + _len_value = getattr(self, "_len_value", None) if _len_value is None: _len_value = self._len_value = mp.Value("i", 0) _len_value.value = value diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index fc07580396b..183f3cf687b 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -236,7 +236,7 @@ def _mark_update_entities(self, index: torch.Tensor) -> None: @property def _cursor(self): - _cursor_value = self.__dict__.get("_cursor_value", None) + _cursor_value = getattr(self, "_cursor_value", None) if not self._compilable: if _cursor_value is None: _cursor_value = self._cursor_value = mp.Value("i", 0) @@ -249,7 +249,7 @@ def _cursor(self): @_cursor.setter def _cursor(self, value): if not self._compilable: - _cursor_value = self.__dict__.get("_cursor_value", None) + _cursor_value = getattr(self, "_cursor_value", None) if _cursor_value is None: _cursor_value = self._cursor_value = mp.Value("i", 0) _cursor_value.value = value @@ -258,7 +258,7 @@ def _cursor(self, value): @property def _write_count(self): - _write_count = self.__dict__.get("_write_count_value", None) + _write_count = getattr(self, "_write_count_value", None) if not self._compilable: if _write_count is None: _write_count = self._write_count_value = mp.Value("i", 0) @@ -271,7 +271,7 @@ def _write_count(self): @_write_count.setter def _write_count(self, value): if not self._compilable: - _write_count = self.__dict__.get("_write_count_value", None) + _write_count = getattr(self, "_write_count_value", None) if _write_count is None: _write_count = self._write_count_value = mp.Value("i", 0) _write_count.value = value @@ -531,14 +531,14 @@ def get_insert_index(self, data: Any) -> int: @property def _write_count(self): - _write_count = self.__dict__.get("_write_count_value", None) + _write_count = getattr(self, "_write_count_value", None) if _write_count is None: _write_count = self._write_count_value = mp.Value("i", 0) return _write_count.value @_write_count.setter def _write_count(self, value): - _write_count = self.__dict__.get("_write_count_value", None) + _write_count = getattr(self, "_write_count_value", None) if _write_count is None: _write_count = self._write_count_value = mp.Value("i", 0) _write_count.value = value