Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
else:
TORCHVISION_VERSION = version.parse("0.0.1")

_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec
_has_mp4 = (
_has_tv and version.parse("0.20") <= TORCHVISION_VERSION < version.parse("0.22")
) or _has_torchcodec

if _has_tb:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
Expand Down
10 changes: 5 additions & 5 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def _storage_keys(self, value):

@property
def _len(self):
_len_value = self.__dict__.get("_len_value", None)
_len_value = getattr(self, "_len_value", None)
if not self._compilable:
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
Expand All @@ -783,7 +783,7 @@ def _len(self):
@_len.setter
def _len(self, value):
if not is_compiling() and not self._compilable:
_len_value = self.__dict__.get("_len_value", None)
_len_value = getattr(self, "_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
Expand All @@ -793,7 +793,7 @@ def _len(self, value):
@property
def _total_shape(self):
# Total shape, irrespective of how full the storage is
_total_shape = self.__dict__.get("_total_shape_value", None)
_total_shape = getattr(self, "_total_shape_value", None)
if _total_shape is None and self.initialized:
if is_tensor_collection(self._storage):
_total_shape = self._storage.shape[: self.ndim]
Expand Down Expand Up @@ -2440,7 +2440,7 @@ def __init__(

@property
def _len(self):
_len_value = self.__dict__.get("_len_value", None)
_len_value = getattr(self, "_len_value", None)
if not self._compilable:
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
Expand All @@ -2452,7 +2452,7 @@ def _len(self):
@_len.setter
def _len(self, value):
if not is_compiling() and not self._compilable:
_len_value = self.__dict__.get("_len_value", None)
_len_value = getattr(self, "_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
Expand Down
12 changes: 6 additions & 6 deletions torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def _mark_update_entities(self, index: torch.Tensor) -> None:

@property
def _cursor(self):
_cursor_value = self.__dict__.get("_cursor_value", None)
_cursor_value = getattr(self, "_cursor_value", None)
if not self._compilable:
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
Expand All @@ -249,7 +249,7 @@ def _cursor(self):
@_cursor.setter
def _cursor(self, value):
if not self._compilable:
_cursor_value = self.__dict__.get("_cursor_value", None)
_cursor_value = getattr(self, "_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
Expand All @@ -258,7 +258,7 @@ def _cursor(self, value):

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
_write_count = getattr(self, "_write_count_value", None)
if not self._compilable:
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
Expand All @@ -271,7 +271,7 @@ def _write_count(self):
@_write_count.setter
def _write_count(self, value):
if not self._compilable:
_write_count = self.__dict__.get("_write_count_value", None)
_write_count = getattr(self, "_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
Expand Down Expand Up @@ -531,14 +531,14 @@ def get_insert_index(self, data: Any) -> int:

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
_write_count = getattr(self, "_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
_write_count = getattr(self, "_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
Expand Down
Loading