Skip to content

Commit 9576a5a

Browse files
stesnastya236
andauthored
Add interactive plots using plotly (#82)
* Added interactive plots * Add plotly integration * Fix during testing, add header * Add plotly to integrations * Delete plotly import from matplotlib * Add plotly to optdepends * Add tests and update Makefile * Checking whether pass test with required plotly * Final check * Delete redundent code * Delete plotly from essential dependencies * Final changes and test to interactive plot * Fix docs * apply pre-commit * Adapt docstring * Add integrations to test dependencies * fixed an issue with figsize * Change default params, delete typo * Fix example in plotly --------- Co-authored-by: nastya236 <[email protected]>
1 parent 09ad1e4 commit 9576a5a

File tree

6 files changed

+264
-10
lines changed

6 files changed

+264
-10
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Python package
33
on:
44
push:
55
branches:
6-
- main
6+
- main
77
pull_request:
88
branches:
99
- main
@@ -53,7 +53,7 @@ jobs:
5353
run: |
5454
python -m pip install --upgrade pip setuptools wheel
5555
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
56-
pip install '.[dev,datasets]'
56+
pip install '.[dev,datasets,integrations]'
5757
5858
- name: Run the formatter
5959
run: |

PKGBUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ optdepends=(
2424
python-matplotlib
2525
python-h5py
2626
python-argparse
27+
python-plotly
2728
)
2829
license=('custom')
2930
arch=('any')

cebra/integrations/matplotlib.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
1111
#
1212
"""Matplotlib interface to CEBRA."""
13-
1413
import abc
1514
from collections.abc import Iterable
1615
from typing import List, Literal, Optional, Tuple, Union
1716

18-
import matplotlib
1917
import matplotlib.axes
2018
import matplotlib.cm
2119
import matplotlib.colors
@@ -483,8 +481,8 @@ def plot(self, **kwargs) -> matplotlib.axes.Axes:
483481
self.ax = self._plot_3d(**kwargs)
484482
else:
485483
self.ax = self._plot_2d(**kwargs)
486-
487-
self.ax.set_title(self.title)
484+
if isinstance(self.ax, matplotlib.axes._axes.Axes):
485+
self.ax.set_title(self.title)
488486

489487
return self.ax
490488

@@ -751,10 +749,8 @@ def plot_overview(
751749
figsize: tuple = (15, 4),
752750
dpi: int = 100,
753751
**kwargs,
754-
) -> Tuple[
755-
matplotlib.figure.Figure,
756-
Tuple[matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes],
757-
]:
752+
) -> Tuple[matplotlib.figure.Figure, Tuple[
753+
matplotlib.axes.Axes, matplotlib.axes.Axes, matplotlib.axes.Axes]]:
758754
"""Plot an overview of a trained CEBRA model.
759755
760756
Args:

cebra/integrations/plotly.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#
2+
# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE,
3+
# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and
4+
# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
5+
#
6+
# Source code:
7+
# https://github.com/AdaptiveMotorControlLab/CEBRA
8+
#
9+
# Please see LICENSE.md for the full license document:
10+
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
11+
#
12+
"""Plotly interface to CEBRA."""
13+
from typing import Optional, Tuple, Union
14+
15+
import matplotlib.colors
16+
import numpy as np
17+
import numpy.typing as npt
18+
import plotly.graph_objects
19+
import torch
20+
21+
from cebra.integrations.matplotlib import _EmbeddingPlot
22+
23+
24+
def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2):
25+
"""Convert matplotlib colormap to plotly colorscale.
26+
27+
Args:
28+
cmap: A registered colormap name from matplotlib.
29+
pl_entries: Number of colors to use in the plotly colorscale.
30+
rdigits: Number of digits to round the colorscale to.
31+
32+
Returns:
33+
pl_colorscale: List of scaled colors to plot the embeddings
34+
"""
35+
scale = np.linspace(0, 1, pl_entries)
36+
colors = (cmap(scale)[:, :3] * 255).astype(np.uint8)
37+
pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"]
38+
for s, color in zip(scale, colors)]
39+
return pl_colorscale
40+
41+
42+
class _EmbeddingInteractivePlot(_EmbeddingPlot):
43+
44+
def __init__(self, **kwargs):
45+
self.figsize = kwargs.get("figsize")
46+
super().__init__(**kwargs)
47+
self.colorscale = self._define_colorscale(self.cmap)
48+
49+
def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]):
50+
"""Define the axis of the plot.
51+
52+
Args:
53+
axis: Optional axis to create the plot on.
54+
55+
Returns:
56+
axis: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
57+
"""
58+
59+
if axis is None:
60+
self.axis = plotly.graph_objects.Figure(
61+
layout=plotly.graph_objects.Layout(height=100 * self.figsize[0],
62+
width=100 * self.figsize[1]))
63+
64+
else:
65+
self.axis = axis
66+
67+
def _define_colorscale(self, cmap: str):
68+
"""Specify the cmap for plotting the latent space.
69+
70+
Args:
71+
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A).
72+
73+
74+
Returns:
75+
colorscale: List of scaled colors to plot the embeddings
76+
"""
77+
colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap))
78+
79+
return colorscale
80+
81+
def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure:
82+
"""Plot the embedding in 3d.
83+
84+
Returns:
85+
The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
86+
"""
87+
88+
idx1, idx2, idx3 = self.idx_order
89+
data = [
90+
plotly.graph_objects.Scatter3d(
91+
x=self.embedding[:, idx1],
92+
y=self.embedding[:, idx2],
93+
z=self.embedding[:, idx3],
94+
mode="markers",
95+
marker=dict(
96+
size=self.markersize,
97+
opacity=self.alpha,
98+
color=self.embedding_labels,
99+
colorscale=self.colorscale,
100+
),
101+
)
102+
]
103+
col = kwargs.get("col", None)
104+
row = kwargs.get("row", None)
105+
106+
if col is None or row is None:
107+
self.axis.add_trace(data[0])
108+
else:
109+
self.axis.add_trace(data[0], row=row, col=col)
110+
111+
self.axis.update_layout(
112+
template="plotly_white",
113+
showlegend=False,
114+
title=self.title,
115+
)
116+
117+
return self.axis
118+
119+
120+
def plot_embedding_interactive(
121+
embedding: Union[npt.NDArray, torch.Tensor],
122+
embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey",
123+
axis: Optional[plotly.graph_objects.Figure] = None,
124+
markersize: float = 1,
125+
idx_order: Optional[Tuple[int]] = None,
126+
alpha: float = 0.4,
127+
cmap: str = "cool",
128+
title: str = "Embedding",
129+
figsize: Tuple[int] = (5, 5),
130+
dpi: int = 100,
131+
**kwargs,
132+
) -> plotly.graph_objects.Figure:
133+
"""Plot embedding in a 3D dimensional space.
134+
135+
This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of
136+
dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1).
137+
138+
The function makes use of :py:func:`plotly.graph_objs._scatter.Scatter` and parameters from that function can be provided
139+
as part of ``kwargs``.
140+
141+
142+
Args:
143+
embedding: A matrix containing the feature representation computed with CEBRA.
144+
embedding_labels: The labels used to map the data to color. It can be:
145+
146+
* A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous.
147+
* A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color.
148+
axis: Optional axis to create the plot on.
149+
idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D
150+
embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those
151+
dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation
152+
(e.g., (2, 4, 5)).
153+
markersize: The marker size.
154+
alpha: The marker blending, between 0 (transparent) and 1 (opaque).
155+
cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A).
156+
title: The title on top of the embedding.
157+
figsize: Figure width and height in inches.
158+
dpi: Figure resolution.
159+
kwargs: Optional arguments to customize the plots. See :py:func:`plotly.graph_objs._scatter.Scatter` documentation for more
160+
details on which arguments to use.
161+
162+
Returns:
163+
The plotly figure.
164+
165+
166+
Example:
167+
168+
>>> import cebra
169+
>>> import numpy as np
170+
>>> X = np.random.uniform(0, 1, (100, 50))
171+
>>> y = np.random.uniform(0, 10, (100, 5))
172+
>>> cebra_model = cebra.CEBRA(max_iterations=10)
173+
>>> cebra_model.fit(X, y)
174+
CEBRA(max_iterations=10)
175+
>>> embedding = cebra_model.transform(X)
176+
>>> cebra_time = np.arange(X.shape[0])
177+
>>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time)
178+
179+
"""
180+
return _EmbeddingInteractivePlot(
181+
embedding=embedding,
182+
embedding_labels=embedding_labels,
183+
axis=axis,
184+
idx_order=idx_order,
185+
markersize=markersize,
186+
alpha=alpha,
187+
cmap=cmap,
188+
title=title,
189+
figsize=figsize,
190+
dpi=dpi,
191+
).plot(**kwargs)

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ datasets =
4848
integrations =
4949
jupyter
5050
pandas
51+
plotly
5152
docs =
5253
sphinx==5.3
5354
sphinx-gallery==0.10.1

tests/test_plotly.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import matplotlib
2+
import numpy as np
3+
import plotly.graph_objects as go
4+
import pytest
5+
from plotly.subplots import make_subplots
6+
7+
import cebra.integrations.plotly as cebra_plotly
8+
import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra
9+
10+
11+
@pytest.mark.parametrize("cmap", ["viridis", "plasma", "inferno", "magma"])
12+
def test_colorscale(cmap):
13+
cmap = matplotlib.cm.get_cmap(cmap)
14+
colorscale = cebra_plotly._convert_cmap2colorscale(cmap)
15+
assert isinstance(colorscale, list)
16+
17+
18+
@pytest.mark.parametrize("output_dimension, idx_order", [(8, (2, 3, 4)),
19+
(3, (0, 1, 2))])
20+
def test_plot_embedding(output_dimension, idx_order):
21+
# example dataset
22+
X = np.random.uniform(0, 1, (1000, 50))
23+
y = np.random.uniform(0, 1, (1000,))
24+
25+
# integration tests
26+
model = cebra_sklearn_cebra.CEBRA(max_iterations=10,
27+
batch_size=512,
28+
output_dimension=output_dimension)
29+
30+
model.fit(X)
31+
embedding = model.transform(X)
32+
33+
fig = cebra_plotly.plot_embedding_interactive(embedding=embedding,
34+
embedding_labels=y)
35+
assert isinstance(fig, go.Figure)
36+
assert len(fig.data) == 1
37+
38+
fig.layout = {}
39+
fig.data = []
40+
41+
fig_subplots = make_subplots(
42+
rows=2,
43+
cols=2,
44+
specs=[
45+
[{
46+
"type": "scatter3d"
47+
}, {
48+
"type": "scatter3d"
49+
}],
50+
[{
51+
"type": "scatter3d"
52+
}, {
53+
"type": "scatter3d"
54+
}],
55+
],
56+
)
57+
58+
fig_subplots = cebra_plotly.plot_embedding_interactive(axis=fig_subplots,
59+
embedding=embedding,
60+
embedding_labels=y,
61+
row=1,
62+
col=1)
63+
64+
fig_subplots.data = []
65+
fig_subplots.layout = {}

0 commit comments

Comments
 (0)