Skip to content

Commit d1c82b1

Browse files
committed
Add test
1 parent c7e0e65 commit d1c82b1

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/test_sklearn_metrics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import cebra
2929
import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra
30+
import cebra.integrations.sklearn.helpers as cebra_sklearn_helpers
3031
import cebra.integrations.sklearn.metrics as cebra_sklearn_metrics
3132

3233

@@ -385,6 +386,36 @@ def test_sklearn_runs_consistency():
385386
invalid_embeddings_runs, between="runs")
386387

387388

389+
def test_align_embeddings():
390+
# Example data
391+
np.random.seed(42)
392+
embedding1 = np.random.uniform(0, 1, (10000, 4))
393+
embedding2 = np.random.uniform(0, 1, (10000, 10))
394+
embedding3 = np.random.uniform(0, 1, (8000, 6))
395+
embeddings_datasets = [embedding1, embedding2, embedding3]
396+
397+
labels1 = np.random.uniform(0, 1, (10000,))
398+
labels2 = np.random.uniform(0, 1, (10000,))
399+
labels3 = np.random.uniform(0, 1, (8000,))
400+
labels_datasets = [labels1, labels2, labels3]
401+
402+
embeddings = cebra_sklearn_helpers.align_embeddings(
403+
embeddings=embeddings_datasets,
404+
labels=labels_datasets,
405+
normalize=False,
406+
n_bins=100)
407+
408+
normalized_embeddings = cebra_sklearn_helpers.align_embeddings(
409+
embeddings=embeddings_datasets,
410+
labels=labels_datasets,
411+
normalize=True,
412+
n_bins=100)
413+
414+
assert len(embeddings) == len(embeddings_datasets)
415+
assert len(normalized_embeddings) == len(embeddings_datasets)
416+
assert len(embeddings) == len(normalized_embeddings)
417+
418+
388419
@pytest.mark.parametrize("seed", [42, 24, 10])
389420
def test_goodness_of_fit_score(seed):
390421
"""

0 commit comments

Comments
 (0)