Commit 0332910

fix linting errors
1 parent b1980cd · commit 0332910

File tree: 3 files changed (+20, -24 lines)


cebra/integrations/sklearn/cebra.py

Lines changed: 3 additions & 1 deletion
@@ -28,9 +28,9 @@
 import numpy as np
 import numpy.typing as npt
 import pkg_resources
+import sklearn
 import sklearn.utils.validation as sklearn_utils_validation
 import torch
-import sklearn
 from sklearn.base import BaseEstimator
 from sklearn.base import TransformerMixin
 from sklearn.utils.metaestimators import available_if
@@ -43,12 +43,14 @@
 import cebra.models
 import cebra.solver
 
+
 def check_version(estimator):
     # NOTE(stes): required as a check for the old way of specifying tags
     # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
     from packaging import version
     return version.parse(sklearn.__version__) < version.parse("1.6.dev")
 
+
 def _init_loader(
     is_cont: bool,
     is_disc: bool,
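The fixes in this file are ordering and spacing: `import sklearn` moves up so the third-party import block stays alphabetically sorted (the order isort would enforce, assuming the project runs isort or an equivalent check), and a second blank line is added before each top-level `def`, which is what pycodestyle's E302 rule expects. A minimal sketch of the post-fix layout:

# Sketch only: sorted third-party imports and two blank lines before a
# top-level definition, matching the layout after this commit.
import numpy as np
import sklearn
import sklearn.utils.validation as sklearn_utils_validation
import torch


def check_version(estimator):
    # The extra blank line above satisfies pycodestyle E302
    # ("expected 2 blank lines").
    from packaging import version
    return version.parse(sklearn.__version__) < version.parse("1.6.dev")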

tests/test_sklearn.py

Lines changed: 0 additions & 2 deletions
@@ -1375,11 +1375,9 @@ def test_new_transform(model_architecture, device):
     # example dataset
     X = np.random.uniform(0, 1, (1000, 50))
     X_s2 = np.random.uniform(0, 1, (800, 30))
-    X_s3 = np.random.uniform(0, 1, (1000, 30))
     y_c1 = np.random.uniform(0, 1, (1000, 5))
     y_c1_s2 = np.random.uniform(0, 1, (800, 5))
     y_c2 = np.random.uniform(0, 1, (1000, 2))
-    y_c2_s2 = np.random.uniform(0, 1, (800, 2))
     y_d = np.random.randint(0, 10, (1000,))
     y_d_s2 = np.random.randint(0, 10, (800,))
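The two deleted lines assign `X_s3` and `y_c2_s2`, which the test never reads again; that is the pattern pyflakes reports as F841 ("local variable is assigned to but never used"), assuming a flake8-style linter is what flagged it. A hypothetical minimal reproduction of the warning:

import numpy as np


def test_example():
    # x_used participates in an assertion; x_unused would be reported as F841.
    x_used = np.random.uniform(0, 1, (10, 3))
    x_unused = np.random.uniform(0, 1, (10, 3))  # linter: assigned but never used
    assert x_used.shape == (10, 3)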

tests/test_solver.py

Lines changed: 17 additions & 21 deletions
@@ -163,7 +163,7 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
 
     solver.fit(loader)
 
-    assert solver.num_sessions == None
+    assert solver.num_sessions is None
     assert solver.n_features == X.shape[1]
 
     embedding = solver.transform(X)
@@ -202,25 +202,25 @@ def test_single_session_auxvar(data_name, loader_initfunc, model_architecture,
 
     pytest.skip("Not yet supported")
 
-    loader = _get_loader(data_name, loader_initfunc)
-    model = _make_model(loader.dataset)
-    behavior_model = _make_behavior_model(loader.dataset)  # noqa: F841
+    # loader = _get_loader(data_name, loader_initfunc)
+    # model = _make_model(loader.dataset)
+    # behavior_model = _make_behavior_model(loader.dataset)  # noqa: F841
 
-    criterion = cebra.models.InfoNCE()
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    # criterion = cebra.models.InfoNCE()
+    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
-    solver = solver_initfunc(
-        model=model,
-        criterion=criterion,
-        optimizer=optimizer,
-    )
+    # solver = solver_initfunc(
+    #     model=model,
+    #     criterion=criterion,
+    #     optimizer=optimizer,
+    # )
 
-    batch = next(iter(loader))
-    assert batch.reference.shape == (32, loader.dataset.input_dimension, 10)
-    log = solver.step(batch)
-    assert isinstance(log, dict)
+    # batch = next(iter(loader))
+    # assert batch.reference.shape == (32, loader.dataset.input_dimension, 10)
+    # log = solver.step(batch)
+    # assert isinstance(log, dict)
 
-    solver.fit(loader)
+    # solver.fit(loader)
 
 
 @pytest.mark.parametrize(
@@ -251,7 +251,7 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
 
     solver.fit(loader)
 
-    assert solver.num_sessions == None
+    assert solver.num_sessions is None
     assert solver.n_features == X.shape[1]
 
     embedding = solver.transform(X)
@@ -513,7 +513,6 @@ def test_select_model_single_session(data_name, model_name, session_id,
     dataset = cebra.datasets.init(data_name)
     model = create_model(model_name, dataset.input_dimension)
     dataset.configure_for(model)
-    loader = _get_loader(dataset, loader_initfunc=loader_initfunc)
     offset = model.get_offset()
     solver = solver_initfunc(model=model, criterion=None, optimizer=None)
 
@@ -633,7 +632,6 @@ def test_batched_transform_single_session(
 
     smallest_batch_length = loader.dataset.neural.shape[0] - batch_size
     offset_ = model.get_offset()
-    padding_left = offset_.left if padding else 0
 
     if smallest_batch_length <= len(offset_):
         with pytest.raises(ValueError):
@@ -685,7 +683,6 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
 
     smallest_batch_length = n_samples - batch_size
     offset_ = model[0].get_offset()
-    padding_left = offset_.left if padding else 0
     for d in dataset._datasets:
         d.offset = offset_
     loader_kwargs = dict(num_steps=10, batch_size=32)
@@ -710,7 +707,6 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
                                          pad_before_transform=padding)
 
         else:
-            model_ = model[i]
             embedding = solver.transform(inputs=inputs.neural,
                                          session_id=i,
                                          pad_before_transform=padding)
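Two kinds of fixes recur in this file. The unused locals (`loader`, `padding_left`, `model_`) are deleted, and the body of `test_single_session_auxvar` that sits after the unconditional `pytest.skip(...)` is commented out rather than removed, presumably to keep the intended test visible without tripping unused-variable warnings on unreachable code. The other change, `== None` becoming `is None`, addresses pycodestyle E711: equality with `None` goes through `__eq__`, which a class can override, while `is None` is an unambiguous identity check. A small sketch of why the identity form is preferred:

class AlwaysEqual:
    # Contrived class whose __eq__ makes "== None" misleading.
    def __eq__(self, other):
        return True


obj = AlwaysEqual()
print(obj == None)  # True  -- the comparison linters flag as E711
print(obj is None)  # False -- identity check cannot be fooled

num_sessions = None
assert num_sessions is None  # the form the diff switches to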
