@@ -675,6 +675,48 @@ def check_first_layer_dim(model, X):
675675 cebra_model .fit ([X , X_s2 ], [y_c1 , y_c1_s2 ], adapt = True )
676676
677677
678+ @_util .parametrize_slow (
679+ arg_names = "model_architecture,device" ,
680+ fast_arguments = list (
681+ itertools .islice (
682+ itertools .product (
683+ cebra_sklearn_cebra .CEBRA .supported_model_architectures (),
684+ _DEVICES ),
685+ 2 ,
686+ )),
687+ slow_arguments = list (
688+ itertools .product (
689+ cebra_sklearn_cebra .CEBRA .supported_model_architectures (),
690+ _DEVICES )),
691+ )
692+ def test_sklearn_adapt_multisession (model_architecture , device ):
693+ num_hidden_units = 32
694+ cebra_model = cebra_sklearn_cebra .CEBRA (
695+ model_architecture = model_architecture ,
696+ time_offsets = 10 ,
697+ learning_rate = 3e-4 ,
698+ max_iterations = 5 ,
699+ max_adapt_iterations = 1 ,
700+ device = device ,
701+ output_dimension = 4 ,
702+ num_hidden_units = num_hidden_units ,
703+ batch_size = 42 ,
704+ verbose = True ,
705+ )
706+
707+ # example dataset
708+ Xs = [np .random .uniform (0 , 1 , (1000 , 50 )) for i in range (3 )]
709+ ys = [np .random .uniform (0 , 1 , (1000 , 5 )) for i in range (3 )]
710+
711+ X_new = np .random .uniform (0 , 1 , (1000 , 50 ))
712+ y_new = np .random .uniform (0 , 1 , (1000 , 5 ))
713+
714+ cebra_model .fit (Xs , ys )
715+
716+ with pytest .raises (NotImplementedError , match = ".*multisession.*" ):
717+ cebra_model .fit (X_new , y_new , adapt = True )
718+
719+
678720@_util .parametrize_slow (
679721 arg_names = "model_architecture,device,pad_before_transform" ,
680722 fast_arguments = list (
0 commit comments