|
27 | 27 |
|
28 | 28 | import cebra |
29 | 29 | import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra |
| 30 | +import cebra.integrations.sklearn.helpers as cebra_sklearn_helpers |
30 | 31 | import cebra.integrations.sklearn.metrics as cebra_sklearn_metrics |
31 | 32 |
|
32 | 33 |
|
@@ -385,6 +386,36 @@ def test_sklearn_runs_consistency(): |
385 | 386 | invalid_embeddings_runs, between="runs") |
386 | 387 |
|
387 | 388 |
|
| 389 | +def test_align_embeddings(): |
| 390 | + # Example data |
| 391 | + np.random.seed(42) |
| 392 | + embedding1 = np.random.uniform(0, 1, (10000, 4)) |
| 393 | + embedding2 = np.random.uniform(0, 1, (10000, 10)) |
| 394 | + embedding3 = np.random.uniform(0, 1, (8000, 6)) |
| 395 | + embeddings_datasets = [embedding1, embedding2, embedding3] |
| 396 | + |
| 397 | + labels1 = np.random.uniform(0, 1, (10000,)) |
| 398 | + labels2 = np.random.uniform(0, 1, (10000,)) |
| 399 | + labels3 = np.random.uniform(0, 1, (8000,)) |
| 400 | + labels_datasets = [labels1, labels2, labels3] |
| 401 | + |
| 402 | + embeddings = cebra_sklearn_helpers.align_embeddings( |
| 403 | + embeddings=embeddings_datasets, |
| 404 | + labels=labels_datasets, |
| 405 | + normalize=False, |
| 406 | + n_bins=100) |
| 407 | + |
| 408 | + normalized_embeddings = cebra_sklearn_helpers.align_embeddings( |
| 409 | + embeddings=embeddings_datasets, |
| 410 | + labels=labels_datasets, |
| 411 | + normalize=True, |
| 412 | + n_bins=100) |
| 413 | + |
| 414 | + assert len(embeddings) == len(embeddings_datasets) |
| 415 | + assert len(normalized_embeddings) == len(embeddings_datasets) |
| 416 | + assert len(embeddings) == len(normalized_embeddings) |
| 417 | + |
| 418 | + |
388 | 419 | @pytest.mark.parametrize("seed", [42, 24, 10]) |
389 | 420 | def test_goodness_of_fit_score(seed): |
390 | 421 | """ |
|
0 commit comments