
Commit 09ad1e4

CeliaBenquet and stes authored
Fix consistency labels ordering and simplify (#87)
* Add test to repro #27
* Fix consistency plot with permuted labels
* Fix consistency labels ordering and simplify
* fix typo

---------

Co-authored-by: Steffen Schneider <[email protected]>
1 parent 94fa87a commit 09ad1e4

File tree

5 files changed: +235 −156 lines changed


cebra/integrations/matplotlib.py

Lines changed: 18 additions & 3 deletions
@@ -524,9 +524,10 @@ def __init__(
         self._define_ax(axis)
         scores = self._check_array(scores)
         # Check the values dimensions
-        if scores.ndim > 2:
+        if scores.ndim >= 2:
             raise ValueError(
                 f"Invalid scores dimensions, expect 1D, got {scores.ndim}D.")
+
         self.labels = self._compute_labels(scores,
                                            pairs=pairs,
                                            datasets=datasets)
@@ -609,7 +610,7 @@ def _compute_labels(
                 "got either both or one of them set to None.")
         else:
             datasets = self._check_array(datasets)
-            pairs = self._check_array(pairs)
+            pairs = self.pairs = self._check_array(pairs)

             if len(pairs.shape) == 2:
                 compared_items = list(sorted(set(pairs[:, 0])))
@@ -651,12 +652,26 @@ def _to_heatmap_format(

         values = np.concatenate(values)

+        pairs = self.pairs
+
+        if pairs.ndim == 3:
+            pairs = pairs[0]
+
+        assert len(pairs) == len(values), (self.pairs.shape, len(values))
+        score_dict = {tuple(pair): value for pair, value in zip(pairs, values)}
+
         if self.labels is None:
             n_grid = self.score

         heatmap_values = np.zeros((len(self.labels), len(self.labels)))
+
         heatmap_values[:] = float("nan")
-        heatmap_values[np.eye(len(self.labels)) == 0] = values
+        for i, label_i in enumerate(self.labels):
+            for j, label_j in enumerate(self.labels):
+                if i == j:
+                    heatmap_values[i, j] = float("nan")
+                else:
+                    heatmap_values[i, j] = score_dict[label_i, label_j]

         return np.minimum(heatmap_values * 100, 99)
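The fill above no longer assumes that `values` arrives in the row-major order implied by `np.eye(...) == 0`; each score is looked up by its `(label_i, label_j)` pair, so permuted pair orderings land in the correct cells. A minimal, self-contained sketch of that label-keyed fill (the labels, pairs, and values below are invented for illustration and are not CEBRA's API):

    import numpy as np

    # Toy inputs: 3 sessions, all 6 ordered off-diagonal pairs, one score per pair.
    labels = ["session1", "session2", "session3"]
    pairs = np.array([["session2", "session1"], ["session1", "session2"],
                      ["session3", "session1"], ["session1", "session3"],
                      ["session2", "session3"], ["session3", "session2"]])
    values = np.array([0.91, 0.88, 0.75, 0.79, 0.83, 0.80])

    # Key scores by (row label, column label) so the fill is order-independent.
    score_dict = {tuple(pair): value for pair, value in zip(pairs, values)}

    heatmap = np.full((len(labels), len(labels)), np.nan)
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if i != j:
                heatmap[i, j] = score_dict[label_i, label_j]

    print(np.round(heatmap * 100, 1))  # diagonal stays NaN, off-diagonal in percent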

cebra/integrations/sklearn/metrics.py

Lines changed: 30 additions & 54 deletions
@@ -171,7 +171,7 @@ def _consistency_datasets(
     Returns:
         A list of scores obtained between embeddings from different datasets (first element),
         a list of pairs of IDs corresponding to the scores (second element), and a list of the
-        datasets (third element).
+        dataset IDs (third element).

     """
     if labels is None:
@@ -217,7 +217,7 @@ def _consistency_datasets(
     pairs = np.array(pairs)[between_dataset]
     scores = _average_scores(np.array(scores)[between_dataset], pairs)

-    return (scores, pairs, datasets)
+    return (scores, pairs, np.array(dataset_ids))


 def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,
@@ -246,61 +246,34 @@ def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,

 def _consistency_runs(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
-    dataset_ids: Optional[List[Union[int, str, float]]],
 ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """Compute consistency between embeddings coming from the same dataset.

-    If no `dataset_ids` is provided, then the embeddings are considered to be coming from the
-    same dataset and consequently not realigned.
-
-    For both modes (``between=runs`` or ``between=datasets``), if no `dataset_ids` is provided
-    (default value is ``None``), then the embeddings are considered individually and the consistency
-    is computed for possible pairs.
-
     Args:
         embeddings: List of embedding matrices.
-        dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
-            associated to the same dataset.

     Returns:
         A list of lists of scores obtained between embeddings of the same dataset (first element),
         a list of lists of pairs of ids of the embeddings of the same datasets that were compared
         (second element), they are identified with :py:class:`numpy.int` from 0 to the number of
-        embeddings for the dataset, and a list of the datasets (third element).
+        embeddings for the dataset, and a list of the unique IDs (third element).
     """
-    # we consider all embeddings as the same dataset
-    if dataset_ids is None:
-        datasets = np.array(["unique"])
-        dataset_ids = ["unique" for i in range(len(embeddings))]
-    else:
-        datasets = np.array(sorted(set(dataset_ids)))
-
-    within_dataset_scores = []
-    within_dataset_pairs = []
-    for dataset in datasets:
-        # get all embeddings for `dataset`
-        dataset_embeddings = [
-            embeddings[i]
-            for i, dataset_id in enumerate(dataset_ids)
-            if dataset_id == dataset
-        ]
-        if len(dataset_embeddings) <= 1:
-            raise ValueError(
-                f"Invalid number of embeddings for dataset {dataset}, expect at least 2 embeddings "
-                f"to be able to compare them, got {len(dataset_embeddings)}")
-        score, pairs = _consistency_scores(embeddings=dataset_embeddings,
-                                           datasets=np.arange(
-                                               len(dataset_embeddings)))
-        within_dataset_scores.append(score)
-        within_dataset_pairs.append(pairs)
+    # NOTE(celia): The number of samples of the embeddings should be the same for all as there is
+    # no realignment, the number of output dimensions can vary between the embeddings we compare.
+    if not all(embeddings[0].shape[0] == embeddings[i].shape[0]
+               for i in range(1, len(embeddings))):
+        raise ValueError(
+            f"Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way."
+            f"If your embeddings are coming from different models, you can use between-datasets"
+        )

-    scores = np.array(within_dataset_scores)
-    pairs = np.array(within_dataset_pairs)
+    run_ids = np.arange(len(embeddings))
+    scores, pairs = _consistency_scores(embeddings=embeddings, datasets=run_ids)

     return (
         _average_scores(scores, pairs),
-        pairs,
-        datasets,
+        np.array(pairs),
+        np.array(run_ids),
     )
@@ -328,15 +301,17 @@ def consistency_score(
             trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
             obtained from models trained on **different datasets**, such as different animals, sessions, etc.
         num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used
-            for embedding alignment. Also see the ``n_bins`` argument in
+            for embedding alignment. Also see the ``n_bins`` argument in
            :py:mod:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
            parameter is used internally. This argument is only used if ``labels``
            is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
            are continuous and not already discrete.

     Returns:
         The list of scores computed between the embeddings (first returns), the list of pairs corresponding
-        to each computed score (second returns) and the list of datasets present in the comparison (third returns).
+        to each computed score (second returns) and the list of id of the entities present in the comparison,
+        either different datasets in the between-datasets comparison or runs in the between-runs comparison
+        (third returns).

     Example:
@@ -346,13 +321,13 @@ def consistency_score(
         >>> embedding2 = np.random.uniform(0, 1, (1000, 8))
         >>> labels1 = np.random.uniform(0, 1, (1000, ))
         >>> labels2 = np.random.uniform(0, 1, (1000, ))
-        >>> # Between-runs, with dataset IDs (optional)
-        >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
-        ...                                                                   dataset_ids=["achilles", "achilles"],
+        >>> # Between-runs consistency
+        >>> scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
        ...                                                                    between="runs")
         >>> # Between-datasets consistency, by aligning on the labels
-        >>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
+        >>> scores, pairs, ids_datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
         ...                                                                    labels=[labels1, labels2],
+        ...                                                                    dataset_ids=["achilles", "buddy"],
         ...                                                                    between="datasets")

     """
@@ -369,12 +344,13 @@ def consistency_score(
         if labels is not None:
             raise ValueError(
                 f"No labels should be provided for between-runs consistency.")
-        scores, pairs, datasets = _consistency_runs(
-            embeddings=embeddings,
-            dataset_ids=dataset_ids,
-        )
+        if dataset_ids is not None:
+            raise ValueError(
+                f"No dataset ID should be provided for between-runs consistency."
+                f"All embeddings should be computed on the same dataset.")
+        scores, pairs, ids = _consistency_runs(embeddings=embeddings,)
     elif between == "datasets":
-        scores, pairs, datasets = _consistency_datasets(
+        scores, pairs, ids = _consistency_datasets(
             embeddings=embeddings,
             dataset_ids=dataset_ids,
             labels=labels,
@@ -383,4 +359,4 @@ def consistency_score(
         raise NotImplementedError(
             f"Invalid comparison, got between={between}, expects either datasets or runs."
         )
-    return scores.squeeze(), pairs.squeeze(), datasets
+    return scores.squeeze(), pairs.squeeze(), ids
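With this change, between-runs consistency is computed over the embeddings alone: no labels, no dataset_ids, embeddings with matching sample counts, and the third return value is the run indices 0..n-1. A short sketch of the new contract, importing consistency_score from cebra.integrations.sklearn.metrics (the module edited here); the try/except illustrates the new dataset_ids check:

    import numpy as np

    from cebra.integrations.sklearn.metrics import consistency_score

    # Two runs on the same (synthetic) dataset: same number of samples, as required.
    embedding1 = np.random.uniform(0, 1, (1000, 8))
    embedding2 = np.random.uniform(0, 1, (1000, 8))

    # Between-runs: embeddings only; the third return value is now the run IDs 0..n-1.
    scores, pairs, ids_runs = consistency_score(embeddings=[embedding1, embedding2],
                                                between="runs")
    print(scores.shape, pairs.shape, ids_runs)  # expected per the asserts above: (2,) (2, 2) [0 1]

    # Supplying dataset_ids together with between="runs" is rejected after this change.
    try:
        consistency_score(embeddings=[embedding1, embedding2],
                          dataset_ids=["achilles", "achilles"],
                          between="runs")
    except ValueError as error:
        print(error)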

docs/source/usage.rst

Lines changed: 20 additions & 19 deletions
@@ -904,47 +904,48 @@ We first create the embeddings to compare: we use two different datasets of data
 .. testcode::

     n_runs = 3
+    dataset_ids = ["session1", "session2"]

     cebra_model = CEBRA(model_architecture='offset10-model',
                         batch_size=512,
                         output_dimension=32,
                         max_iterations=5,
                         time_offsets=10)

-    embeddings, dataset_ids, labels = [], [], []
+    embeddings_runs = []
+    embeddings_datasets, ids, labels = [], [], []
     for i in range(n_runs):
-        embeddings.append(cebra_model.fit_transform(neural_session1, continuous_label1))
-        dataset_ids.append("session1")
-        labels.append(continuous_label1[:, 0])
+        embeddings_runs.append(cebra_model.fit_transform(neural_session1, continuous_label1))

-    embeddings.append(cebra_model.fit_transform(neural_session2, continuous_label2))
-    dataset_ids.append("session2")
-    labels.append(continuous_label2[:, 0])
+    labels.append(continuous_label1[:, 0])
+    embeddings_datasets.append(embeddings_runs[-1])

-    n_datasets = len(set(dataset_ids))
+    embeddings_datasets.append(cebra_model.fit_transform(neural_session2, continuous_label2))
+    labels.append(continuous_label2[:, 0])
+
+    n_datasets = len(dataset_ids)

 To get the :py:func:`~.consistency_score` on the set of embeddings that we just generated:

 .. testcode::

-    # Between-runs, with dataset IDs (optional)
-    scores_runs, pairs_runs, datasets_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
-                                                                                     dataset_ids=dataset_ids,
-                                                                                     between="runs")
+    # Between-runs
+    scores_runs, pairs_runs, ids_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings_runs,
+                                                                                between="runs")
     assert scores_runs.shape == (n_runs**2 - n_runs, )
-    assert pairs_runs.shape == (n_datasets, n_runs*n_datasets, 2)
-    assert datasets_runs.shape == (n_datasets, )
+    assert pairs_runs.shape == (n_runs**2 - n_runs, 2)
+    assert ids_runs.shape == (n_runs, )

     # Between-datasets, by aligning on the labels
     (scores_datasets,
      pairs_datasets,
-     datasets_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
+     ids_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings_datasets,
                                                                   labels=labels,
                                                                   dataset_ids=dataset_ids,
                                                                   between="datasets")
     assert scores_datasets.shape == (n_datasets**2 - n_datasets, )
-    assert pairs_datasets.shape == (n_runs*(n_runs*n_datasets), 2)
-    assert datasets_datasets.shape == (n_datasets, )
+    assert pairs_datasets.shape == (n_datasets**2 - n_datasets, 2)
+    assert ids_datasets.shape == (n_datasets, )

 .. admonition:: See API docs
     :class: dropdown
@@ -961,8 +962,8 @@ You can then display the resulting scores using :py:func:`~.plot_consistency`.
     ax1 = fig.add_subplot(121)
     ax2 = fig.add_subplot(122)

-    ax1 = cebra.plot_consistency(scores_runs, pairs_runs, datasets_runs, vmin=0, vmax=100, ax=ax1, title="Between-runs consistencies")
-    ax2 = cebra.plot_consistency(scores_datasets, pairs_datasets, datasets_datasets, vmin=0, vmax=100, ax=ax2, title="Between-subjects consistencies")
+    ax1 = cebra.plot_consistency(scores_runs, pairs_runs, ids_runs, vmin=0, vmax=100, ax=ax1, title="Between-runs consistencies")
+    ax2 = cebra.plot_consistency(scores_datasets, pairs_datasets, ids_runs, vmin=0, vmax=100, ax=ax2, title="Between-subjects consistencies")


 .. figure:: docs-imgs/consistency-score.png
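The corrected shape asserts in the tutorial follow from counting ordered pairs: n compared entities give n*(n-1) = n**2 - n off-diagonal scores (one per ordered pair, each pair of 2 IDs), plus one ID per entity. A quick arithmetic check with the tutorial's values:

    # n_runs and n_datasets as set in the tutorial snippet above.
    n_runs, n_datasets = 3, 2
    assert n_runs ** 2 - n_runs == 6          # 6 ordered run pairs -> scores (6,), pairs (6, 2), ids (3,)
    assert n_datasets ** 2 - n_datasets == 2  # 2 ordered dataset pairs -> scores (2,), pairs (2, 2), ids (2,)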
