@@ -110,7 +110,7 @@ def _consistency_scores(
110110 Args:
111111 embeddings: List of embedding matrices.
112112 dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
113- associated to the same dataset.
113+ associated to the same dataset.
114114
115115 Returns:
116116 List of the consistencies for each embeddings pair (first element) and
@@ -145,6 +145,7 @@ def _consistency_datasets(
145145 embeddings : List [Union [npt .NDArray , torch .Tensor ]],
146146 dataset_ids : Optional [List [Union [int , str , float ]]],
147147 labels : List [Union [npt .NDArray , torch .Tensor ]],
148+ num_discretization_bins : int = 100
148149) -> Tuple [npt .NDArray , npt .NDArray , npt .NDArray ]:
149150 """Compute consistency between embeddings from different datasets.
150151
@@ -158,9 +159,14 @@ def _consistency_datasets(
158159 Args:
159160 embeddings: List of embedding matrices.
160161 dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
161- associated to the same dataset.
162+ associated to the same dataset.
162163 labels: List of labels corresponding to each embedding and to use for alignment
163164 between them.
165+ num_discretization_bins: Number of values for the discretized common labels. The discretized labels are used
166+ for embedding alignment. Also see the ``n_bins`` argument in
167+ :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
168+ parameter is used internally. This argument is only used if ``labels``
169+ is not ``None`` and the given labels are continuous and not already discrete.
164170
165171 Returns:
166172 A list of scores obtained between embeddings from different datasets (first element),
@@ -203,7 +209,7 @@ def _consistency_datasets(
203209
204210 # NOTE(celia): with default values normalized=True and n_bins = 100
205211 aligned_embeddings = cebra_sklearn_helpers .align_embeddings (
206- embeddings , labels )
212+ embeddings , labels , n_bins = num_discretization_bins )
207213 scores , pairs = _consistency_scores (aligned_embeddings ,
208214 datasets = dataset_ids )
209215 between_dataset = [p [0 ] != p [1 ] for p in pairs ]
@@ -303,6 +309,7 @@ def consistency_score(
303309 between : Optional [Literal ["datasets" , "runs" ]] = None ,
304310 labels : Optional [List [Union [npt .NDArray , torch .Tensor ]]] = None ,
305311 dataset_ids : Optional [List [Union [int , str , float ]]] = None ,
312+ num_discretization_bins : int = 100
306313) -> Tuple [npt .NDArray , npt .NDArray , npt .NDArray ]:
307314 """Compute the consistency score between embeddings, either between runs or between datasets.
308315
@@ -320,6 +327,12 @@ def consistency_score(
320327 *Consistency between runs* means the consistency between embeddings obtained from multiple models
321328 trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
322329 obtained from models trained on **different datasets**, such as different animals, sessions, etc.
330+ num_discretization_bins: Number of values for the discretized common labels. The discretized labels are used
331+ for embedding alignment. Also see the ``n_bins`` argument in
332+ :py:func:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
333+ parameter is used internally. This argument is only used if ``labels``
334+ is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
335+ are continuous and not already discrete.
323336
324337 Returns:
325338 The list of scores computed between the embeddings (first returns), the list of pairs corresponding
@@ -356,12 +369,16 @@ def consistency_score(
356369 if labels is not None :
357370 raise ValueError (
358371 f"No labels should be provided for between-runs consistency." )
359- scores , pairs , datasets = _consistency_runs (embeddings = embeddings ,
360- dataset_ids = dataset_ids )
372+ scores , pairs , datasets = _consistency_runs (
373+ embeddings = embeddings ,
374+ dataset_ids = dataset_ids ,
375+ )
361376 elif between == "datasets" :
362- scores , pairs , datasets = _consistency_datasets (embeddings = embeddings ,
363- dataset_ids = dataset_ids ,
364- labels = labels )
377+ scores , pairs , datasets = _consistency_datasets (
378+ embeddings = embeddings ,
379+ dataset_ids = dataset_ids ,
380+ labels = labels ,
381+ num_discretization_bins = num_discretization_bins )
365382 else :
366383 raise NotImplementedError (
367384 f"Invalid comparison, got between={ between } , expects either datasets or runs."
0 commit comments