@@ -171,7 +171,7 @@ def _consistency_datasets(
     Returns:
         A list of scores obtained between embeddings from different datasets (first element),
         a list of pairs of IDs corresponding to the scores (second element), and a list of the
-        datasets (third element).
+        dataset IDs (third element).
 
     """
     if labels is None:
@@ -217,7 +217,7 @@ def _consistency_datasets(
     pairs = np.array(pairs)[between_dataset]
     scores = _average_scores(np.array(scores)[between_dataset], pairs)
 
-    return (scores, pairs, datasets)
+    return (scores, pairs, np.array(dataset_ids))
 
 
 def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,
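For readers skimming the diff, here is a rough sketch of how the reshaped return value might be consumed through the public `consistency_score` entry point (the synthetic data and dataset names are placeholders; the call signature follows the doctest further down in this file):

```python
import numpy as np
import cebra

# Two embeddings from different (hypothetical) datasets, aligned via labels.
embedding1 = np.random.uniform(0, 1, (1000, 8))
embedding2 = np.random.uniform(0, 1, (800, 8))
labels1 = np.random.uniform(0, 1, (1000,))
labels2 = np.random.uniform(0, 1, (800,))

scores, pairs, ids = cebra.sklearn.metrics.consistency_score(
    embeddings=[embedding1, embedding2],
    labels=[labels1, labels2],
    dataset_ids=["achilles", "buddy"],
    between="datasets",
)
# After this change, the third element holds the dataset IDs themselves,
# so each row of `pairs` can be matched back to the IDs passed in.
print(ids)  # e.g. ['achilles' 'buddy']
```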
@@ -246,61 +246,34 @@ def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,
 
 def _consistency_runs(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
-    dataset_ids: Optional[List[Union[int, str, float]]],
 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """Compute consistency between embeddings coming from the same dataset.
 
-    If no `dataset_ids` is provided, then the embeddings are considered to be coming from the
-    same dataset and consequently not realigned.
-
-    For both modes (``between=runs`` or ``between=datasets``), if no `dataset_ids` is provided
-    (default value is ``None``), then the embeddings are considered individually and the consistency
-    is computed for possible pairs.
-
     Args:
         embeddings: List of embedding matrices.
-        dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
-            associated to the same dataset.
 
     Returns:
         A list of lists of scores obtained between embeddings of the same dataset (first element),
         a list of lists of pairs of ids of the embeddings of the same datasets that were compared
         (second element), they are identified with :py:class:`numpy.int` from 0 to the number of
-        embeddings for the dataset, and a list of the datasets (third element).
+        embeddings for the dataset, and a list of the unique run IDs (third element).
     """
-    # we consider all embeddings as the same dataset
-    if dataset_ids is None:
-        datasets = np.array(["unique"])
-        dataset_ids = ["unique" for i in range(len(embeddings))]
-    else:
-        datasets = np.array(sorted(set(dataset_ids)))
-
-    within_dataset_scores = []
-    within_dataset_pairs = []
-    for dataset in datasets:
-        # get all embeddings for `dataset`
-        dataset_embeddings = [
-            embeddings[i]
-            for i, dataset_id in enumerate(dataset_ids)
-            if dataset_id == dataset
-        ]
-        if len(dataset_embeddings) <= 1:
-            raise ValueError(
-                f"Invalid number of embeddings for dataset {dataset}, expect at least 2 embeddings "
-                f"to be able to compare them, got {len(dataset_embeddings)}")
-        score, pairs = _consistency_scores(embeddings=dataset_embeddings,
-                                           datasets=np.arange(
-                                               len(dataset_embeddings)))
-        within_dataset_scores.append(score)
-        within_dataset_pairs.append(pairs)
+    # NOTE(celia): The number of samples must be the same for all embeddings, as
+    # there is no realignment; the number of output dimensions can vary between
+    # the embeddings we compare.
+    if not all(embeddings[0].shape[0] == embeddings[i].shape[0]
+               for i in range(1, len(embeddings))):
+        raise ValueError(
+            "Invalid embeddings, all embeddings should have the same number of "
+            "samples to be compared in a between-runs way. If your embeddings "
+            "come from different datasets, use between=datasets instead.")
 
-    scores = np.array(within_dataset_scores)
-    pairs = np.array(within_dataset_pairs)
+    run_ids = np.arange(len(embeddings))
+    scores, pairs = _consistency_scores(embeddings=embeddings, datasets=run_ids)
 
     return (
         _average_scores(scores, pairs),
-        pairs,
-        datasets,
+        np.array(pairs),
+        np.array(run_ids),
     )
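A minimal sketch of the contract the rewritten `_consistency_runs` enforces, in plain NumPy (the arrays are synthetic; only the shape check mirrors the code above):

```python
import numpy as np

# Between-runs comparisons skip realignment, so every embedding must share the
# same number of samples (axis 0); latent dimensionality (axis 1) may vary.
embeddings = [
    np.random.uniform(0, 1, (1000, 8)),  # run 0
    np.random.uniform(0, 1, (1000, 4)),  # run 1: fewer latent dimensions is fine
    np.random.uniform(0, 1, (1000, 8)),  # run 2
]
assert all(e.shape[0] == embeddings[0].shape[0] for e in embeddings)

# Runs are identified by their position in the input list.
run_ids = np.arange(len(embeddings))
print(run_ids)  # [0 1 2]
```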
@@ -328,15 +301,17 @@ def consistency_score(
             trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
             obtained from models trained on **different datasets**, such as different animals, sessions, etc.
         num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used
-            for embedding alignment. Also see the ``n_bins`` argument in
+            for embedding alignment. Also see the ``n_bins`` argument in
             :py:mod:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
             parameter is used internally. This argument is only used if ``labels``
             is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
             are continuous and not already discrete.
 
     Returns:
         The list of scores computed between the embeddings (first returns), the list of pairs corresponding
-        to each computed score (second returns) and the list of datasets present in the comparison (third returns).
+        to each computed score (second returns) and the list of IDs of the entities present in the comparison,
+        either the dataset IDs in the between-datasets comparison or the run indices in the between-runs
+        comparison (third returns).
 
     Example:
 
@@ -346,13 +321,13 @@ def consistency_score(
         >>> embedding2 = np.random.uniform(0, 1, (1000, 8))
         >>> labels1 = np.random.uniform(0, 1, (1000, ))
         >>> labels2 = np.random.uniform(0, 1, (1000, ))
-        >>> # Between-runs, with dataset IDs (optional)
-        >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
-        ...                                                                   dataset_ids=["achilles", "achilles"],
+        >>> # Between-runs consistency
+        >>> scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
         ...                                                                   between="runs")
         >>> # Between-datasets consistency, by aligning on the labels
-        >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
+        >>> scores, pairs, ids_datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
         ...                                                                   labels=[labels1, labels2],
+        ...                                                                   dataset_ids=["achilles", "buddy"],
         ...                                                                   between="datasets")
 
     """
@@ -369,12 +344,13 @@ def consistency_score(
         if labels is not None:
             raise ValueError(
                 f"No labels should be provided for between-runs consistency.")
-        scores, pairs, datasets = _consistency_runs(
-            embeddings=embeddings,
-            dataset_ids=dataset_ids,
-        )
+        if dataset_ids is not None:
+            raise ValueError(
+                "No dataset ID should be provided for between-runs consistency. "
+                "All embeddings should be computed on the same dataset.")
+        scores, pairs, ids = _consistency_runs(embeddings=embeddings)
     elif between == "datasets":
-        scores, pairs, datasets = _consistency_datasets(
+        scores, pairs, ids = _consistency_datasets(
             embeddings=embeddings,
             dataset_ids=dataset_ids,
             labels=labels,
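To illustrate the stricter validation in the `between="runs"` branch, a short sketch of a call it now rejects (the embeddings are synthetic; the error text follows the strings added above):

```python
import numpy as np
import cebra

embeddings = [np.random.uniform(0, 1, (1000, 8)) for _ in range(2)]

# Between-runs consistency assumes a single shared dataset, so passing
# dataset IDs (or labels) alongside between="runs" now raises a ValueError.
try:
    cebra.sklearn.metrics.consistency_score(
        embeddings=embeddings,
        dataset_ids=["achilles", "achilles"],
        between="runs",
    )
except ValueError as err:
    print(err)
```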
@@ -383,4 +359,4 @@ def consistency_score(
         raise NotImplementedError(
             f"Invalid comparison, got between={between}, expects either datasets or runs."
         )
-    return scores.squeeze(), pairs.squeeze(), datasets
+    return scores.squeeze(), pairs.squeeze(), ids
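As a usage note, the flat `(scores, pairs, ids)` triple returned here can be rearranged into a pairwise matrix for plotting. The `to_matrix` helper below is a hypothetical sketch, not part of the module, and assumes each entry of `pairs` holds the (row, column) IDs for the matching score:

```python
import numpy as np

def to_matrix(scores, pairs, ids):
    # Hypothetical helper: arrange flat consistency scores into a square
    # matrix indexed by `ids` (run indices or dataset IDs). Assumes that
    # pairs[k] gives the (row, column) entities for scores[k].
    index = {entity: i for i, entity in enumerate(np.asarray(ids).tolist())}
    matrix = np.full((len(index), len(index)), np.nan)
    for (first, second), score in zip(np.atleast_2d(pairs), np.atleast_1d(scores)):
        matrix[index[first], index[second]] = score
    return matrix
```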