
Commit 009f4ce

[Formatting] Update pre-commit (#3108)
1 parent d34dbb2 commit 009f4ce

17 files changed: 38 additions, 38 deletions

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v5.0.0
     hooks:
       - id: check-docstring-first
       - id: check-toml
@@ -11,7 +11,7 @@ repos:
       - id: end-of-file-fixer

   - repo: https://github.com/omnilib/ufmt
-    rev: v2.0.0b2
+    rev: v2.8.0
     hooks:
       - id: ufmt
         additional_dependencies:
@@ -20,7 +20,7 @@ repos:
           - libcst == 0.4.7

   - repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
+    rev: 7.3.0
     hooks:
       - id: flake8
         args: [--config=setup.cfg]
@@ -31,13 +31,13 @@ repos:
           - flake8-print==5.0.0

   - repo: https://github.com/PyCQA/pydocstyle
-    rev: 6.1.1
+    rev: 6.3.0
     hooks:
       - id: pydocstyle
         files: ^torchrl/

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.9.0
+    rev: v3.20.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus]
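Note: the bumped hook versions pull in newer pyflakes / pycodestyle rules, which is presumably what motivates the mechanical changes in the rest of this commit: `# noqa: F824` on redundant `global`/`nonlocal` declarations, and `type(...) ==` rewritten to `type(...) is`. A minimal sketch of the two patterns the updated flake8 reports (names are made up for illustration):

```python
_COUNTER = 0


def read_counter():
    # F824: a `global` (or `nonlocal`) declaration is redundant when the name
    # is only read or mutated, never rebound, in this scope.
    global _COUNTER  # would need "# noqa: F824" to silence the warning
    return _COUNTER


def same_class(a, b):
    # E721: for exact type checks, compare classes with `is` rather than `==`.
    return type(a) is type(b)
```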

benchmarks/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 import pytest
 from torchrl._utils import logger as torchrl_logger

-CALL_TIMES = defaultdict(lambda: 0.0)
+CALL_TIMES = defaultdict(float)


 def pytest_sessionfinish(maxprint=50):
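The lambda default factories are swapped for the built-in constructors; `float()` returns `0.0` and `str()` returns `""`, so behaviour is unchanged (the same applies to the `defaultdict(str)` change in test/test_env.py below). A quick sketch:

```python
from collections import defaultdict

# Passing the constructor is equivalent to the previous lambda factories:
# float() -> 0.0, str() -> "".
times = defaultdict(float)
times["step"] += 1.5
assert times["missing"] == 0.0

confs = defaultdict(str)
assert confs["missing"] == ""
```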

test/_utils_internal.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def __init__(self, old_emit):
            self.old_emit = old_emit

        def __call__(self, record):
-            nonlocal records
+            nonlocal records  # noqa: F824
            self.old_emit(record)
            if record_name in record.name:
                records.append(record)
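pyflakes F824 flags a `global`/`nonlocal` declaration for a name that is never assigned in that scope; here `records` is only mutated in place via `append`, which needs no declaration, so the statement is kept for readability and the warning silenced instead. An illustrative sketch (names are made up):

```python
def make_collector():
    records = []

    def collect(record):
        # Appending mutates the existing list object, so `nonlocal` is not
        # required here; declaring it anyway is what pyflakes reports as F824.
        records.append(record)

    return records, collect
```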

test/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 import pytest
 import torch

-CALL_TIMES = defaultdict(lambda: 0.0)
+CALL_TIMES = defaultdict(float)
 IS_OSX = sys.platform == "darwin"


test/llm/test_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -2386,7 +2386,7 @@ def test_batching_continuous_throughput(
         assert len(processing_events) > 0, "No processing occurred"

         # Check that processing happened across multiple threads (indicating concurrent processing)
-        thread_ids = {event["thread_id"] for event in processing_events}  # noqa
+        thread_ids = set(event["thread_id"] for event in processing_events)
         assert (
             len(thread_ids) > 1
         ), f"All processing happened in single thread: {thread_ids}"

test/test_env.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@
     _atari_found = True
 except FileNotFoundError:
     _atari_found = False
-    atari_confs = defaultdict(lambda: "")
+    atari_confs = defaultdict(str)

 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import (

test/test_libs.py

Lines changed: 1 addition & 1 deletion
@@ -3410,7 +3410,7 @@ def test_d4rl_iteration(self, task, split_trajs):

 def _minari_init() -> tuple[bool, Exception | None]:
     """Initialize Minari datasets list. Returns True if already initialized."""
-    global _MINARI_DATASETS
+    global _MINARI_DATASETS  # noqa: F824
     if _MINARI_DATASETS and not all(
         isinstance(x, str) and x.isdigit() for x in _MINARI_DATASETS
     ):

torchrl/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1176,7 +1176,7 @@ def auto_unwrap_transformed_env(allow_none=False):
         bool or None: The current setting for automatically unwrapping TransformedEnv
             instances.
     """
-    global _AUTO_UNWRAP
+    global _AUTO_UNWRAP  # noqa: F824
     if _AUTO_UNWRAP is None and allow_none:
         return None
     elif _AUTO_UNWRAP is None:

torchrl/data/llm/history.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ def add_chat_template(
           `name_or_path` attribute.
         - Templates are stored globally and persist for the duration of the Python session.
     """
-    global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS
+    global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS  # noqa: F824

     # Validate template contains generation blocks
     if "{% generation %}" not in template:

torchrl/data/tensor_specs.py

Lines changed: 10 additions & 10 deletions
@@ -523,7 +523,7 @@ def __eq__(self, other):
                 return True
             return False
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.low.dtype == other.low.dtype
             and self.high.dtype == other.high.dtype
             and self.device == other.device
@@ -627,7 +627,7 @@ class TensorSpec(metaclass=abc.ABCMeta):

     shape: torch.Size
     space: None | Box
-    device: torch.device | None = None
+    device: torch.device | None = None  # noqa # type: ignore
     dtype: torch.dtype = torch.get_default_dtype()
     domain: str = ""
     _encode_memo_dict: dict[Any, Callable[[Any], Any]] = field(
@@ -679,7 +679,7 @@ def decorator(func):
         return decorator

     @property
-    def device(self) -> torch.device:
+    def device(self) -> torch.device:  # noqa # type: ignore
         """The device of the spec.

         Only :class:`Composite` specs can have a ``None`` device. All leaves must have a non-null device.
@@ -2163,7 +2163,7 @@ def __eq__(self, other):
                 and (self.mask == other.mask).all()
             )
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.shape == other.shape
             and self.space == other.space
             and self.device == other.device
@@ -2444,7 +2444,7 @@ def cardinality(self) -> int:

     def __eq__(self, other):
         return (
-            type(other) == type(self)
+            type(other) is type(self)
             and self.device == other.device
             and self.shape == other.shape
             and self.space == other.space
@@ -3443,7 +3443,7 @@ def __eq__(self, other):
                 and (self.mask == other.mask).all()
             )
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.shape == other.shape
             and self.space == other.space
             and self.device == other.device
@@ -4042,7 +4042,7 @@ def __eq__(self, other):
                 and (self.mask == other.mask).all()
             )
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.shape == other.shape
             and self.space == other.space
             and self.device == other.device
@@ -4731,7 +4731,7 @@ def __eq__(self, other):
                 and (self.mask == other.mask).all()
             )
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.shape == other.shape
             and self.space == other.space
             and self.device == other.device
@@ -5585,7 +5585,7 @@ def is_in(self, val: dict | TensorDictBase) -> bool:
         # return False
         # if val.shape[-self.ndim:] != self.shape:
         #     return False
-        if self.data_cls is not None and type(val) != self.data_cls:
+        if self.data_cls is not None and type(val) is not self.data_cls:
             return False
         for key, item in self._specs.items():
             if item is None or (isinstance(item, Composite) and item.is_empty()):
@@ -5902,7 +5902,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:

     def __eq__(self, other: object) -> bool:
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.shape == other.shape
             and self._device == other._device
             and set(self._specs.keys()) == set(other._specs.keys())
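`type(x) == type(y)` and `type(x) is type(y)` agree for ordinary classes, but `is` is the identity check that E721 recommends for exact-type comparisons, and it cannot be affected by a metaclass overriding `__eq__`. A small sketch:

```python
class A:
    pass


class B(A):
    pass


a, b = A(), B()

# Identity of the class objects: True only for exactly the same class.
assert type(a) is type(a)
assert type(a) is not type(b)

# When subclasses should also match, isinstance() is the right tool instead.
assert isinstance(b, A)
```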
