Skip to content

Commit c94f5ae

Browse files
committed
Add tests for exception handling
1 parent 040b5a2 commit c94f5ae

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

tests/test_sklearn_metrics.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,9 @@ def test_goodness_of_fit_score(seed):
395395
max_iterations=5,
396396
batch_size=512,
397397
)
398-
X = torch.tensor(np.random.uniform(0, 1, (5000, 50)))
399-
y = torch.tensor(np.random.uniform(0, 1, (5000, 5)))
398+
generator = torch.Generator().manual_seed(seed)
399+
X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
400+
y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
400401
cebra_model.fit(X, y)
401402
score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
402403
X,
@@ -447,3 +448,66 @@ def _fit_and_get_history(X, y):
447448
assert history_linear.shape[0] > 0
448449

449450
assert np.all(history_linear[-20:] > history_random[-20:])
451+
452+
453+
@pytest.mark.parametrize("seed", [42, 24, 10])
454+
def test_infonce_to_goodness_of_fit(seed):
455+
"""Test the conversion from InfoNCE loss to goodness of fit metric."""
456+
# Test with model
457+
cebra_model = cebra_sklearn_cebra.CEBRA(
458+
model_architecture="offset10-model",
459+
max_iterations=5,
460+
batch_size=128,
461+
)
462+
generator = torch.Generator().manual_seed(seed)
463+
X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
464+
cebra_model.fit(X)
465+
466+
# Test single value
467+
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
468+
model=cebra_model)
469+
assert isinstance(gof, float)
470+
471+
# Test array of values
472+
infonce_values = np.array([1.0, 2.0, 3.0])
473+
gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
474+
infonce_values, model=cebra_model)
475+
assert isinstance(gof_array, np.ndarray)
476+
assert gof_array.shape == infonce_values.shape
477+
478+
# Test with explicit batch_size and num_sessions
479+
gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
480+
batch_size=128,
481+
num_sessions=1)
482+
assert isinstance(gof, float)
483+
484+
# Test error cases
485+
with pytest.raises(ValueError, match="batch_size.*should not be provided"):
486+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
487+
model=cebra_model,
488+
batch_size=128)
489+
490+
with pytest.raises(ValueError, match="batch_size.*should not be provided"):
491+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
492+
model=cebra_model,
493+
num_sessions=1)
494+
495+
# Test with unfitted model
496+
unfitted_model = cebra_sklearn_cebra.CEBRA()
497+
with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
498+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
499+
model=unfitted_model)
500+
501+
# Test with model having batch_size=None
502+
none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None)
503+
none_batch_model.fit(X)
504+
with pytest.raises(ValueError, match="Computing the goodness of fit"):
505+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
506+
model=none_batch_model)
507+
508+
# Test missing batch_size or num_sessions when model is None
509+
with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
510+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)
511+
512+
with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
513+
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)

0 commit comments

Comments
 (0)