Skip to content

Commit 3acbdf4

Browse files
stesCeliaBenquet
authored andcommitted
fix linting errors
1 parent 7dfd4b9 commit 3acbdf4

File tree

3 files changed

+18
-23
lines changed

3 files changed

+18
-23
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _safe_torch_load(filename, weights_only, **kwargs):
7777

7878

7979

80+
8081
def _init_loader(
8182
is_cont: bool,
8283
is_disc: bool,

tests/test_sklearn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,11 +1375,9 @@ def test_new_transform(model_architecture, device):
13751375
# example dataset
13761376
X = np.random.uniform(0, 1, (1000, 50))
13771377
X_s2 = np.random.uniform(0, 1, (800, 30))
1378-
X_s3 = np.random.uniform(0, 1, (1000, 30))
13791378
y_c1 = np.random.uniform(0, 1, (1000, 5))
13801379
y_c1_s2 = np.random.uniform(0, 1, (800, 5))
13811380
y_c2 = np.random.uniform(0, 1, (1000, 2))
1382-
y_c2_s2 = np.random.uniform(0, 1, (800, 2))
13831381
y_d = np.random.randint(0, 10, (1000,))
13841382
y_d_s2 = np.random.randint(0, 10, (800,))
13851383

tests/test_solver.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
166166

167167
solver.fit(loader)
168168

169-
assert solver.num_sessions == None
169+
assert solver.num_sessions is None
170170
assert solver.n_features == X.shape[1]
171171

172172
embedding = solver.transform(X)
@@ -231,25 +231,25 @@ def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, s
231231

232232
pytest.skip("Not yet supported")
233233

234-
loader = _get_loader(data_name, loader_initfunc)
235-
model = _make_model(loader.dataset)
236-
behavior_model = _make_behavior_model(loader.dataset) # noqa: F841
234+
# loader = _get_loader(data_name, loader_initfunc)
235+
# model = _make_model(loader.dataset)
236+
# behavior_model = _make_behavior_model(loader.dataset) # noqa: F841
237237

238-
criterion = cebra.models.InfoNCE()
239-
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
238+
# criterion = cebra.models.InfoNCE()
239+
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
240240

241-
solver = solver_initfunc(
242-
model=model,
243-
criterion=criterion,
244-
optimizer=optimizer,
245-
)
241+
# solver = solver_initfunc(
242+
# model=model,
243+
# criterion=criterion,
244+
# optimizer=optimizer,
245+
# )
246246

247-
batch = next(iter(loader))
248-
assert batch.reference.shape == (32, loader.dataset.input_dimension, 10)
249-
log = solver.step(batch)
250-
assert isinstance(log, dict)
247+
# batch = next(iter(loader))
248+
# assert batch.reference.shape == (32, loader.dataset.input_dimension, 10)
249+
# log = solver.step(batch)
250+
# assert isinstance(log, dict)
251251

252-
solver.fit(loader)
252+
# solver.fit(loader)
253253

254254

255255
@pytest.mark.parametrize(
@@ -280,7 +280,7 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
280280

281281
solver.fit(loader)
282282

283-
assert solver.num_sessions == None
283+
assert solver.num_sessions is None
284284
assert solver.n_features == X.shape[1]
285285

286286
embedding = solver.transform(X)
@@ -721,7 +721,6 @@ def test_select_model_single_session(data_name, model_name, session_id,
721721
dataset = cebra.datasets.init(data_name)
722722
model = create_model(model_name, dataset.input_dimension)
723723
dataset.configure_for(model)
724-
loader = _get_loader(dataset, loader_initfunc=loader_initfunc)
725724
offset = model.get_offset()
726725
solver = solver_initfunc(model=model, criterion=None, optimizer=None)
727726

@@ -841,7 +840,6 @@ def test_batched_transform_single_session(
841840

842841
smallest_batch_length = loader.dataset.neural.shape[0] - batch_size
843842
offset_ = model.get_offset()
844-
padding_left = offset_.left if padding else 0
845843

846844
if smallest_batch_length <= len(offset_):
847845
with pytest.raises(ValueError):
@@ -893,7 +891,6 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
893891

894892
smallest_batch_length = n_samples - batch_size
895893
offset_ = model[0].get_offset()
896-
padding_left = offset_.left if padding else 0
897894
for d in dataset._datasets:
898895
d.offset = offset_
899896
loader_kwargs = dict(num_steps=10, batch_size=32)
@@ -918,7 +915,6 @@ def test_batched_transform_multi_session(data_name, model_name, padding,
918915
pad_before_transform=padding)
919916

920917
else:
921-
model_ = model[i]
922918
embedding = solver.transform(inputs=inputs.neural,
923919
session_id=i,
924920
pad_before_transform=padding)

0 commit comments

Comments
 (0)