Skip to content

Commit 4d68110

Browse files
authored
Merge branch 'main' into batched-inference-and-padding
2 parents 66fc6aa + a5814bb commit 4d68110

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

cebra/integrations/sklearn/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def align_embeddings(
155155
quantized_sample / np.linalg.norm(quantized_sample, axis=0)
156156
for quantized_sample in quantized_embedding
157157
]
158+
quantized_embeddings.append(quantized_embedding_norm)
159+
else:
160+
quantized_embeddings.append(quantized_embedding)
158161

159-
quantized_embeddings.append(quantized_embedding_norm)
160162
return quantized_embeddings

tests/test_sklearn_metrics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import cebra
2929
import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra
30+
import cebra.integrations.sklearn.helpers as cebra_sklearn_helpers
3031
import cebra.integrations.sklearn.metrics as cebra_sklearn_metrics
3132

3233

@@ -385,6 +386,36 @@ def test_sklearn_runs_consistency():
385386
invalid_embeddings_runs, between="runs")
386387

387388

389+
def test_align_embeddings():
390+
# Example data
391+
np.random.seed(42)
392+
embedding1 = np.random.uniform(0, 1, (10000, 4))
393+
embedding2 = np.random.uniform(0, 1, (10000, 10))
394+
embedding3 = np.random.uniform(0, 1, (8000, 6))
395+
embeddings_datasets = [embedding1, embedding2, embedding3]
396+
397+
labels1 = np.random.uniform(0, 1, (10000,))
398+
labels2 = np.random.uniform(0, 1, (10000,))
399+
labels3 = np.random.uniform(0, 1, (8000,))
400+
labels_datasets = [labels1, labels2, labels3]
401+
402+
embeddings = cebra_sklearn_helpers.align_embeddings(
403+
embeddings=embeddings_datasets,
404+
labels=labels_datasets,
405+
normalize=False,
406+
n_bins=100)
407+
408+
normalized_embeddings = cebra_sklearn_helpers.align_embeddings(
409+
embeddings=embeddings_datasets,
410+
labels=labels_datasets,
411+
normalize=True,
412+
n_bins=100)
413+
414+
assert len(embeddings) == len(embeddings_datasets)
415+
assert len(normalized_embeddings) == len(embeddings_datasets)
416+
assert len(embeddings) == len(normalized_embeddings)
417+
418+
388419
@pytest.mark.parametrize("seed", [42, 24, 10])
389420
def test_goodness_of_fit_score(seed):
390421
"""

0 commit comments

Comments
 (0)