Skip to content

Commit b8d9424

Browse files
authored
Add missing error raising and test for adapt on Multissession (#114)
Add missing error raising and test
1 parent e2d9831 commit b8d9424

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ def _adapt_model(
935935

936936
dataset, is_multisession = self._prepare_data(X, y)
937937

938-
if is_multisession:
938+
if is_multisession or isinstance(self.model_, nn.ModuleList):
939939
raise NotImplementedError(
940940
"The adapt option with a multisession training is not handled. Please use adapt=True for single-trained estimators only."
941941
)

tests/test_sklearn.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,48 @@ def check_first_layer_dim(model, X):
675675
cebra_model.fit([X, X_s2], [y_c1, y_c1_s2], adapt=True)
676676

677677

678+
@_util.parametrize_slow(
679+
arg_names="model_architecture,device",
680+
fast_arguments=list(
681+
itertools.islice(
682+
itertools.product(
683+
cebra_sklearn_cebra.CEBRA.supported_model_architectures(),
684+
_DEVICES),
685+
2,
686+
)),
687+
slow_arguments=list(
688+
itertools.product(
689+
cebra_sklearn_cebra.CEBRA.supported_model_architectures(),
690+
_DEVICES)),
691+
)
692+
def test_sklearn_adapt_multisession(model_architecture, device):
693+
num_hidden_units = 32
694+
cebra_model = cebra_sklearn_cebra.CEBRA(
695+
model_architecture=model_architecture,
696+
time_offsets=10,
697+
learning_rate=3e-4,
698+
max_iterations=5,
699+
max_adapt_iterations=1,
700+
device=device,
701+
output_dimension=4,
702+
num_hidden_units=num_hidden_units,
703+
batch_size=42,
704+
verbose=True,
705+
)
706+
707+
# example dataset
708+
Xs = [np.random.uniform(0, 1, (1000, 50)) for i in range(3)]
709+
ys = [np.random.uniform(0, 1, (1000, 5)) for i in range(3)]
710+
711+
X_new = np.random.uniform(0, 1, (1000, 50))
712+
y_new = np.random.uniform(0, 1, (1000, 5))
713+
714+
cebra_model.fit(Xs, ys)
715+
716+
with pytest.raises(NotImplementedError, match=".*multisession.*"):
717+
cebra_model.fit(X_new, y_new, adapt=True)
718+
719+
678720
@_util.parametrize_slow(
679721
arg_names="model_architecture,device,pad_before_transform",
680722
fast_arguments=list(

0 commit comments

Comments
 (0)