@@ -395,8 +395,9 @@ def test_goodness_of_fit_score(seed):
395395 max_iterations = 5 ,
396396 batch_size = 512 ,
397397 )
398- X = torch .tensor (np .random .uniform (0 , 1 , (5000 , 50 )))
399- y = torch .tensor (np .random .uniform (0 , 1 , (5000 , 5 )))
398+ generator = torch .Generator ().manual_seed (seed )
399+ X = torch .rand (5000 , 50 , dtype = torch .float32 , generator = generator )
400+ y = torch .rand (5000 , 5 , dtype = torch .float32 , generator = generator )
400401 cebra_model .fit (X , y )
401402 score = cebra_sklearn_metrics .goodness_of_fit_score (cebra_model ,
402403 X ,
@@ -447,3 +448,66 @@ def _fit_and_get_history(X, y):
447448 assert history_linear .shape [0 ] > 0
448449
449450 assert np .all (history_linear [- 20 :] > history_random [- 20 :])
451+
452+
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_infonce_to_goodness_of_fit(seed):
    """Test the conversion from InfoNCE loss to goodness of fit metric."""
    # Fit a small model so the conversion can pull batch size and session
    # count from a real, fitted estimator.
    cebra_model = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset10-model",
        max_iterations=5,
        batch_size=128,
    )
    generator = torch.Generator().manual_seed(seed)
    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
    cebra_model.fit(X)

    # A scalar InfoNCE value converts to a scalar goodness-of-fit.
    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
        1.0, model=cebra_model)
    assert isinstance(gof, float)

    # An array of InfoNCE values converts elementwise, preserving shape.
    infonce_values = np.array([1.0, 2.0, 3.0])
    gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
        infonce_values, model=cebra_model)
    assert isinstance(gof_array, np.ndarray)
    assert gof_array.shape == infonce_values.shape

    # Explicit batch_size/num_sessions work without passing a model.
    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
        1.0, batch_size=128, num_sessions=1)
    assert isinstance(gof, float)

    # Error: a model and explicit batch_size/num_sessions are exclusive.
    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(
            1.0, model=cebra_model, batch_size=128)

    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(
            1.0, model=cebra_model, num_sessions=1)

    # Error: the model must be fitted before conversion.
    unfitted_model = cebra_sklearn_cebra.CEBRA()
    with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(
            1.0, model=unfitted_model)

    # Error: a model trained with batch_size=None is not supported.
    none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None)
    none_batch_model.fit(X)
    with pytest.raises(ValueError, match="Computing the goodness of fit"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(
            1.0, model=none_batch_model)

    # Error: without a model, both batch_size and num_sessions are required.
    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)

    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)
0 commit comments