Skip to content

Commit 7dfd4b9

Browse files
stesCeliaBenquet
authored andcommitted
apply ruff auto-fixes
1 parent 7b0cc68 commit 7dfd4b9

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

cebra/data/single_session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ def __post_init__(self):
370370

371371
self._init_behavior_distribution()
372372
self._init_time_distribution()
373+
374+
if self.conditional != "time_delta":
375+
raise NotImplementedError(
376+
"Hybrid training is currently only implemented using the ``time_delta`` "
377+
"continual distribution.")
373378

374379
def _init_behavior_distribution(self):
375380
if self.conditional == "time":

cebra/datasets/gaussian_mixture.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
import cebra.data
2929
import cebra.io
30-
from cebra.datasets import get_datapath
3130
from cebra.datasets import parametrize
3231
from cebra.datasets import register
3332

cebra/solver/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@
3333
import abc
3434
import os
3535
import warnings
36-
from typing import Callable, Dict, Iterable, List, Literal, Optional, Union
36+
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
3737

3838
import literate_dataclasses as dataclasses
3939
import numpy.typing as npt
4040
import numpy as np
4141
import numpy.typing as npt
4242
import torch
4343
import torch.nn.functional as F
44-
import tqdm
4544
from torch.utils.data import DataLoader
4645
from torch.utils.data import Dataset
4746

tests/test_solver.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
# limitations under the License.
2121
#
2222
import copy
23-
import itertools
2423
import tempfile
2524

2625
import numpy as np
@@ -715,8 +714,8 @@ def create_model(model_name, input_dimension):
715714

716715
@pytest.mark.parametrize(
717716
"data_name, model_name ,session_id, loader_initfunc, solver_initfunc",
718-
single_session_tests_select_model +
719-
single_session_hybrid_tests_select_model)
717+
single_session_tests_select_model + single_session_hybrid_tests_select_model
718+
)
720719
def test_select_model_single_session(data_name, model_name, session_id,
721720
loader_initfunc, solver_initfunc):
722721
dataset = cebra.datasets.init(data_name)

0 commit comments

Comments
 (0)