@@ -383,3 +383,132 @@ def test_sklearn_runs_consistency():
383383 with pytest .raises (ValueError , match = "Invalid.*embeddings" ):
384384 _ , _ , _ = cebra_sklearn_metrics .consistency_score (
385385 invalid_embeddings_runs , between = "runs" )
386+
387+
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_score(seed):
    """
    Ensure that the GoF score is close to 0 for a model fit on random data.
    """
    # Purely random data/labels: the model should learn nothing, so the
    # goodness-of-fit score should stay near zero.
    rng = torch.Generator().manual_seed(seed)
    features = torch.rand(5000, 50, dtype=torch.float32, generator=rng)
    labels = torch.rand(5000, 5, dtype=torch.float32, generator=rng)

    model = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset1-model",
        max_iterations=5,
        batch_size=512,
    )
    model.fit(features, labels)

    score = cebra_sklearn_metrics.goodness_of_fit_score(model,
                                                        features,
                                                        labels,
                                                        session_id=0,
                                                        num_batches=500)
    assert isinstance(score, float)
    assert np.isclose(score, 0, atol=0.01)
409+
410+
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_goodness_of_fit_history(seed):
    """
    Ensure that the GoF score is higher for a model fit on data with underlying
    structure than for a model fit on random data.
    """
    # Build one shared input and two label sets: pure noise vs. a linear
    # function of the input (which the model can actually learn).
    rng = torch.Generator().manual_seed(seed)
    features = torch.rand(1000, 50, dtype=torch.float32, generator=rng)
    noise_labels = torch.rand(len(features), 5, dtype=torch.float32, generator=rng)
    projection = torch.randn(50, 5, dtype=torch.float32, generator=rng)
    structured_labels = features @ projection

    def _train_and_score(inputs, targets):
        model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
                                          max_iterations=150,
                                          batch_size=512,
                                          device="cpu")
        model.fit(inputs, targets)
        gof_history = cebra_sklearn_metrics.goodness_of_fit_history(model)
        # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical
        # values due to numerical issues.
        return gof_history[5:]

    history_random = _train_and_score(features, noise_labels)
    history_linear = _train_and_score(features, structured_labels)

    assert isinstance(history_random, np.ndarray)
    assert history_random.shape[0] > 0
    # For random labels the GoF should hover around zero; only check the
    # non-negative entries against that expectation.
    non_negative = history_random[history_random >= 0]
    np.testing.assert_allclose(non_negative, 0, atol=0.075)

    assert isinstance(history_linear, np.ndarray)
    assert history_linear.shape[0] > 0

    # Once training has converged, the structured data must score strictly
    # higher than noise at every one of the final iterations.
    assert np.all(history_linear[-20:] > history_random[-20:])
451+
452+
@pytest.mark.parametrize("seed", [42, 24, 10])
def test_infonce_to_goodness_of_fit(seed):
    """Test the conversion from InfoNCE loss to goodness of fit metric."""
    # Fit a minimal model so conversion can read batch size / session count
    # from it.
    rng = torch.Generator().manual_seed(seed)
    data = torch.rand(1000, 50, dtype=torch.float32, generator=rng)
    fitted = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset10-model",
        max_iterations=5,
        batch_size=128,
    )
    fitted.fit(data)

    # Scalar input -> scalar output.
    single = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, model=fitted)
    assert isinstance(single, float)

    # Array input -> array output of matching shape.
    losses = np.array([1.0, 2.0, 3.0])
    converted = cebra_sklearn_metrics.infonce_to_goodness_of_fit(losses,
                                                                 model=fitted)
    assert isinstance(converted, np.ndarray)
    assert converted.shape == losses.shape

    # Passing batch_size/num_sessions directly instead of a model also works.
    explicit = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
                                                                batch_size=128,
                                                                num_sessions=1)
    assert isinstance(explicit, float)

    # A model together with explicit batch_size/num_sessions is rejected.
    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
                                                         model=fitted,
                                                         batch_size=128)
    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
                                                         model=fitted,
                                                         num_sessions=1)

    # An unfitted model cannot be used.
    not_fitted = cebra_sklearn_cebra.CEBRA(max_iterations=5)
    with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, model=not_fitted)

    # A fitted model trained without a fixed batch size cannot be converted.
    full_batch = cebra_sklearn_cebra.CEBRA(batch_size=None, max_iterations=5)
    full_batch.fit(data)
    with pytest.raises(ValueError, match="Computing the goodness of fit"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, model=full_batch)

    # Without a model, both batch_size and num_sessions are required.
    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)
    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)
0 commit comments