Commit c4c2ce0
Update pre-commits (#733)
1 parent 2cb0bf5

16 files changed: +69 -69 lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -14,25 +14,25 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.0
+    rev: v3.19.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.2
+    rev: v0.9.6
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/python-poetry/poetry
-    rev: 1.8.0
+    rev: 1.8.5
     hooks:
       - id: poetry-check
       - id: poetry-lock
         args:
           - "--check"
           - "--no-update"
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.21.2
+    rev: v8.23.3
     hooks:
       - id: gitleaks
   - repo: https://github.com/woodruffw/zizmor-pre-commit
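
All four bumps are routine revision updates; the ruff jump from v0.8.2 to v0.9.6 is the one that drives the formatting churn in the remaining files of this commit. As a sketch only (not part of this commit), this kind of bump can be reproduced with pre-commit's own updater, wrapped here in Python to match the language of the other examples:

import subprocess

# Sketch: `pre-commit autoupdate` rewrites each `rev:` in
# .pre-commit-config.yaml to the latest tag, and `pre-commit run --all-files`
# re-applies every hook so formatter changes (like ruff 0.9's) land at once.
subprocess.run(["pre-commit", "autoupdate"], check=True)
subprocess.run(["pre-commit", "run", "--all-files"], check=True)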

lerobot/common/datasets/factory.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
         )
         logging.info(
             "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index , indent=2)}"
+            f"{pformat(dataset.repo_id_to_index, indent=2)}"
         )

     if cfg.dataset.use_imagenet_stats:
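
The fix here is cosmetic: a stray space before the comma inside the f-string, which the bumped ruff formatter now normalizes. For reference, a minimal sketch of what `pformat(..., indent=2)` renders; the dataset names below are invented for illustration:

from pprint import pformat

# Hypothetical mapping from dataset repo ids to dataset indices.
repo_id_to_index = {f"lerobot/dataset_{i}": i for i in range(6)}

# pformat keeps short dicts on one line; once the text exceeds the default
# 80-character width, each entry wraps onto its own line, indented by 2.
print(pformat(repo_id_to_index, indent=2))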

lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
         # However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
         # are too far appart.
         direction="nearest",
-        tolerance=pd.Timedelta(f"{1/fps} seconds"),
+        tolerance=pd.Timedelta(f"{1 / fps} seconds"),
     )
     # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
     df = df[df["episode_index"] != -1]
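
Again purely cosmetic: ruff 0.9 formats expressions inside f-strings, so `1/fps` gains spaces around the operator. A minimal sketch of the nearest-timestamp alignment this tolerance drives, assuming a `pd.merge_asof`-style join; the frames and column names below are invented:

import pandas as pd

fps = 30
tolerance = pd.Timedelta(f"{1 / fps} seconds")  # ~33.3 ms matching window

# Invented data: a reference camera clock and a second stream to align.
ref = pd.DataFrame({"timestamp_utc": pd.to_datetime([0, 33, 66], unit="ms")})
other = pd.DataFrame(
    {"timestamp_utc": pd.to_datetime([1, 35, 200], unit="ms"), "frame": [0, 1, 2]}
)

# direction="nearest" picks the closest timestamp in either direction;
# rows farther than `tolerance` from any match are left unmatched (NaN).
aligned = pd.merge_asof(ref, other, on="timestamp_utc", direction="nearest", tolerance=tolerance)
print(aligned)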

lerobot/common/policies/act/modeling_act.py

Lines changed: 3 additions & 3 deletions
@@ -409,9 +409,9 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
         latent dimension.
         """
         if self.config.use_vae and self.training:
-            assert (
-                "action" in batch
-            ), "actions must be provided when using the variational objective in training mode."
+            assert "action" in batch, (
+                "actions must be provided when using the variational objective in training mode."
+            )

         batch_size = (
             batch["observation.images"]

lerobot/common/policies/diffusion/configuration_diffusion.py

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )

     @property
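
The old line relied on implicit string concatenation left over from an earlier wrapping; the new formatter merges the adjacent literals into one. The merge is behavior-preserving, as a quick sketch shows (the keys here are invented):

# Adjacent string literals (including f-strings) are joined at parse time,
# so both forms produce the identical message:
key, first_image_key = "observation.images.top", "observation.images.front"
before = f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
after = f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
assert before == after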

lerobot/common/policies/tdmpc/modeling_tdmpc.py

Lines changed: 3 additions & 3 deletions
@@ -594,9 +594,9 @@ def _apply_fn(m):

         self.apply(_apply_fn)
         for m in [self._reward, *self._Qs]:
-            assert isinstance(
-                m[-1], nn.Linear
-            ), "Sanity check. The last linear layer needs 0 initialization on weights."
+            assert isinstance(m[-1], nn.Linear), (
+                "Sanity check. The last linear layer needs 0 initialization on weights."
+            )
             nn.init.zeros_(m[-1].weight)
             nn.init.zeros_(m[-1].bias)  # this has already been done, but keep this line here for good measure

lerobot/common/policies/vqbet/configuration_vqbet.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def validate_features(self) -> None:
         for key, image_ft in self.image_features.items():
             if image_ft.shape != first_image_ft.shape:
                 raise ValueError(
-                    f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
+                    f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
                 )

     @property

lerobot/common/policies/vqbet/vqbet_utils.py

Lines changed: 22 additions & 22 deletions
@@ -203,9 +203,9 @@ def __init__(self, config: VQBeTConfig):
     def forward(self, input, targets=None):
         device = input.device
         b, t, d = input.size()
-        assert (
-            t <= self.config.gpt_block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        assert t <= self.config.gpt_block_size, (
+            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
+        )

         # positional encodings that are added to the input embeddings
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

@@ -273,10 +273,10 @@ def configure_parameters(self):
         assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
             str(inter_params)
         )
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), "parameters {} were not separated into either decay/no_decay set!".format(
-            str(param_dict.keys() - union_params),
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters {} were not separated into either decay/no_decay set!".format(
+                str(param_dict.keys() - union_params),
+            )
         )

         decay = [param_dict[pn] for pn in sorted(decay)]

@@ -419,9 +419,9 @@ def get_codebook_vector_from_indices(self, indices):
         # and the network should be able to reconstruct

         if quantize_dim < self.num_quantizers:
-            assert (
-                self.quantize_dropout > 0.0
-            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            assert self.quantize_dropout > 0.0, (
+                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            )
             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

         # get ready for gathering

@@ -472,9 +472,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=
         all_indices = []

         if return_loss:
-            assert not torch.any(
-                indices == -1
-            ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            assert not torch.any(indices == -1), (
+                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+            )
             ce_losses = []

         should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

@@ -887,9 +887,9 @@ def calculate_ce_loss(codes):
         # only calculate orthogonal loss for the activated codes for this batch

         if self.orthogonal_reg_active_codes_only:
-            assert not (
-                is_multiheaded and self.separate_codebook_per_head
-            ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            assert not (is_multiheaded and self.separate_codebook_per_head), (
+                "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
+            )
             unique_code_ids = torch.unique(embed_ind)
             codebook = codebook[:, unique_code_ids]

@@ -999,9 +999,9 @@ def gumbel_sample(
     ind = sampling_logits.argmax(dim=dim)
     one_hot = F.one_hot(ind, size).type(dtype)

-    assert not (
-        reinmax and not straight_through
-    ), "reinmax can only be turned on if using straight through gumbel softmax"
+    assert not (reinmax and not straight_through), (
+        "reinmax can only be turned on if using straight through gumbel softmax"
+    )

     if not straight_through or temperature <= 0.0 or not training:
         return ind, one_hot

@@ -1209,9 +1209,9 @@ def __init__(
         self.gumbel_sample = gumbel_sample
         self.sample_codebook_temp = sample_codebook_temp

-        assert not (
-            use_ddp and num_codebooks > 1 and kmeans_init
-        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
+            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
+        )

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop

lerobot/common/robot_devices/control_utils.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f

     def log_dt(shortname, dt_val_s):
         nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
         if fps is not None:
             actual_fps = 1 / dt_val_s
             if actual_fps < fps - 1:
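
Same f-string normalization (`1/ dt_val_s` becomes `1 / dt_val_s`). For readers squinting at the format specs, a sketch of what the log line produces for a hypothetical loop time:

dt_val_s = 0.0213  # hypothetical 21.3 ms control-loop iteration
# :5.2f pads the milliseconds to width 5 with 2 decimals; :3.1f does the
# same for the rate in Hz.
print(f"dt:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)")
# -> dt:21.30 (46.9hz)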

lerobot/common/utils/io_utils.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def _deserialize(target, source):
     # Check that they have exactly the same set of keys.
     if target.keys() != source.keys():
         raise ValueError(
-            f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
+            f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
         )

     # Recursively update each key.

0 commit comments