@@ -856,6 +856,39 @@ def get_ordered_cuda_devices():
856856ordered_cuda_devices = get_ordered_cuda_devices () if torch .cuda .is_available (
857857) else []
858858
859+ def test_fit_after_moving_to_device ():
860+ expected_device = 'cpu'
861+ expected_type = type (expected_device )
862+
863+ X = np .random .uniform (0 , 1 , (10 , 5 ))
864+ cebra_model = cebra_sklearn_cebra .CEBRA (model_architecture = "offset1-model" ,
865+ max_iterations = 5 ,
866+ device = expected_device )
867+
868+ assert type (cebra_model .device ) == expected_type
869+ assert cebra_model .device == expected_device
870+
871+ cebra_model .partial_fit (X )
872+ assert type (cebra_model .device ) == expected_type
873+ assert cebra_model .device == expected_device
874+ if hasattr (cebra_model , 'device_' ):
875+ assert type (cebra_model .device_ ) == expected_type
876+ assert cebra_model .device_ == expected_device
877+
878+ # Move the model to device using the to() method
879+ cebra_model .to ('cpu' )
880+ assert type (cebra_model .device ) == expected_type
881+ assert cebra_model .device == expected_device
882+ if hasattr (cebra_model , 'device_' ):
883+ assert type (cebra_model .device_ ) == expected_type
884+ assert cebra_model .device_ == expected_device
885+
886+ cebra_model .partial_fit (X )
887+ assert type (cebra_model .device ) == expected_type
888+ assert cebra_model .device == expected_device
889+ if hasattr (cebra_model , 'device_' ):
890+ assert type (cebra_model .device_ ) == expected_type
891+ assert cebra_model .device_ == expected_device
859892
860893@pytest .mark .parametrize ("device" , ['cpu' ] + ordered_cuda_devices )
861894def test_move_cpu_to_cuda_device (device ):
@@ -875,9 +908,12 @@ def test_move_cpu_to_cuda_device(device):
875908 new_device = 'cpu' if device .startswith ('cuda' ) else 'cuda:0'
876909 cebra_model .to (new_device )
877910
878- assert cebra_model .device == torch .device (new_device )
879- assert next (cebra_model .solver_ .model .parameters ()).device == torch .device (
880- new_device )
911+ assert cebra_model .device == new_device
912+ device_model = next (cebra_model .solver_ .model .parameters ()).device
913+ device_str = str (device_model )
914+ if device_model .type == 'cuda' :
915+ device_str = f'cuda:{ device_model .index } '
916+ assert device_str == new_device
881917
882918 with tempfile .NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
883919 cebra_model .save (savefile .name )
@@ -903,9 +939,10 @@ def test_move_cpu_to_mps_device(device):
903939 new_device = 'cpu' if device == 'mps' else 'mps'
904940 cebra_model .to (new_device )
905941
906- assert cebra_model .device == torch .device (new_device )
907- assert next (cebra_model .solver_ .model .parameters ()).device == torch .device (
908- new_device )
942+ assert cebra_model .device == new_device
943+
944+ device_model = next (cebra_model .solver_ .model .parameters ()).device
945+ assert device_model .type == new_device
909946
910947 with tempfile .NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
911948 cebra_model .save (savefile .name )
@@ -939,9 +976,12 @@ def test_move_mps_to_cuda_device(device):
939976 new_device = 'mps' if device .startswith ('cuda' ) else 'cuda:0'
940977 cebra_model .to (new_device )
941978
942- assert cebra_model .device == torch .device (new_device )
943- assert next (cebra_model .solver_ .model .parameters ()).device == torch .device (
944- new_device )
979+ assert cebra_model .device == new_device
980+ device_model = next (cebra_model .solver_ .model .parameters ()).device
981+ device_str = str (device_model )
982+ if device_model .type == 'cuda' :
983+ device_str = f'cuda:{ device_model .index } '
984+ assert device_str == new_device
945985
946986 with tempfile .NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
947987 cebra_model .save (savefile .name )
@@ -963,11 +1003,7 @@ def test_mps():
9631003
9641004 if torch .backends .mps .is_available () and torch .backends .mps .is_built ():
9651005 torch .backends .mps .is_available = lambda : False
966- with pytest .raises (ValueError ):
967- cebra_model .fit (X )
968-
969- torch .backends .mps .is_available = lambda : True
970- torch .backends .mps .is_built = lambda : False
1006+
9711007 with pytest .raises (ValueError ):
9721008 cebra_model .fit (X )
9731009
0 commit comments