Skip to content

Commit d37e7f9

Browse files
authored
Expose n_bins argument from align_embeddings (#25)
* Expose n_bins argument from align_embeddings * Fix docs * Add sklearn helper functions to public docs
1 parent e011694 commit d37e7f9

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _consistency_scores(
110110
Args:
111111
embeddings: List of embedding matrices.
112112
dataset_ids: List of dataset IDs associated with each embedding. Multiple embeddings can be
113-
associated to the same dataset.
113+
associated to the same dataset.
114114
115115
Returns:
116116
List of the consistencies for each embeddings pair (first element) and
@@ -145,6 +145,7 @@ def _consistency_datasets(
145145
embeddings: List[Union[npt.NDArray, torch.Tensor]],
146146
dataset_ids: Optional[List[Union[int, str, float]]],
147147
labels: List[Union[npt.NDArray, torch.Tensor]],
148+
num_discretization_bins: int = 100
148149
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
149150
"""Compute consistency between embeddings from different datasets.
150151
@@ -158,9 +159,14 @@ def _consistency_datasets(
158159
Args:
159160
embeddings: List of embedding matrices.
160161
dataset_ids: List of dataset IDs associated with each embedding. Multiple embeddings can be
161-
associated to the same dataset.
162+
associated to the same dataset.
162163
labels: List of labels corresponding to each embedding and to use for alignment
163164
between them.
165+
num_discretization_bins: Number of values for the digitized common labels. The discretized labels are used
166+
for embedding alignment. Also see the ``n_bins`` argument in
167+
:py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
168+
parameter is used internally. This argument is only used if ``labels``
169+
is not ``None`` and the given labels are continuous and not already discrete.
164170
165171
Returns:
166172
A list of scores obtained between embeddings from different datasets (first element),
@@ -203,7 +209,7 @@ def _consistency_datasets(
203209

204210
# NOTE(celia): with default values normalized=True and n_bins = 100
205211
aligned_embeddings = cebra_sklearn_helpers.align_embeddings(
206-
embeddings, labels)
212+
embeddings, labels, n_bins=num_discretization_bins)
207213
scores, pairs = _consistency_scores(aligned_embeddings,
208214
datasets=dataset_ids)
209215
between_dataset = [p[0] != p[1] for p in pairs]
@@ -303,6 +309,7 @@ def consistency_score(
303309
between: Optional[Literal["datasets", "runs"]] = None,
304310
labels: Optional[List[Union[npt.NDArray, torch.Tensor]]] = None,
305311
dataset_ids: Optional[List[Union[int, str, float]]] = None,
312+
num_discretization_bins: int = 100
306313
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
307314
"""Compute the consistency score between embeddings, either between runs or between datasets.
308315
@@ -320,6 +327,12 @@ def consistency_score(
320327
*Consistency between runs* means the consistency between embeddings obtained from multiple models
321328
trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
322329
obtained from models trained on **different datasets**, such as different animals, sessions, etc.
330+
num_discretization_bins: Number of values for the digitized common labels. The discretized labels are used
331+
for embedding alignment. Also see the ``n_bins`` argument in
332+
:py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
333+
parameter is used internally. This argument is only used if ``labels``
334+
is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
335+
are continuous and not already discrete.
323336
324337
Returns:
325338
The list of scores computed between the embeddings (first returns), the list of pairs corresponding
@@ -356,12 +369,16 @@ def consistency_score(
356369
if labels is not None:
357370
raise ValueError(
358371
f"No labels should be provided for between-runs consistency.")
359-
scores, pairs, datasets = _consistency_runs(embeddings=embeddings,
360-
dataset_ids=dataset_ids)
372+
scores, pairs, datasets = _consistency_runs(
373+
embeddings=embeddings,
374+
dataset_ids=dataset_ids,
375+
)
361376
elif between == "datasets":
362-
scores, pairs, datasets = _consistency_datasets(embeddings=embeddings,
363-
dataset_ids=dataset_ids,
364-
labels=labels)
377+
scores, pairs, datasets = _consistency_datasets(
378+
embeddings=embeddings,
379+
dataset_ids=dataset_ids,
380+
labels=labels,
381+
num_discretization_bins=num_discretization_bins)
365382
else:
366383
raise NotImplementedError(
367384
f"Invalid comparison, got between={between}, expects either datasets or runs."

docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ these components in other contexts and research code bases.
2525
api/sklearn/cebra
2626
api/sklearn/metrics
2727
api/sklearn/decoder
28+
api/sklearn/helpers
2829

2930

3031

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Helper functions
2+
----------------
3+
4+
.. automodule:: cebra.integrations.sklearn.helpers
5+
:show-inheritance:
6+
:members:
7+

0 commit comments

Comments
 (0)