|
| 1 | +# |
| 2 | +# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE, |
| 3 | +# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and |
| 4 | +# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023. |
| 5 | +# |
| 6 | +# Source code: |
| 7 | +# https://github.com/AdaptiveMotorControlLab/CEBRA |
| 8 | +# |
| 9 | +# Please see LICENSE.md for the full license document: |
| 10 | +# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md |
| 11 | +# |
| 12 | +"""Plotly interface to CEBRA.""" |
| 13 | +from typing import Optional, Tuple, Union |
| 14 | + |
| 15 | +import matplotlib.colors |
| 16 | +import numpy as np |
| 17 | +import numpy.typing as npt |
| 18 | +import plotly.graph_objects |
| 19 | +import torch |
| 20 | + |
| 21 | +from cebra.integrations.matplotlib import _EmbeddingPlot |
| 22 | + |
| 23 | + |
| 24 | +def _convert_cmap2colorscale(cmap: str, pl_entries: int = 11, rdigits: int = 2): |
| 25 | + """Convert matplotlib colormap to plotly colorscale. |
| 26 | +
|
| 27 | + Args: |
| 28 | + cmap: A registered colormap name from matplotlib. |
| 29 | + pl_entries: Number of colors to use in the plotly colorscale. |
| 30 | + rdigits: Number of digits to round the colorscale to. |
| 31 | +
|
| 32 | + Returns: |
| 33 | + pl_colorscale: List of scaled colors to plot the embeddings |
| 34 | + """ |
| 35 | + scale = np.linspace(0, 1, pl_entries) |
| 36 | + colors = (cmap(scale)[:, :3] * 255).astype(np.uint8) |
| 37 | + pl_colorscale = [[round(s, rdigits), f"rgb{tuple(color)}"] |
| 38 | + for s, color in zip(scale, colors)] |
| 39 | + return pl_colorscale |
| 40 | + |
| 41 | + |
| 42 | +class _EmbeddingInteractivePlot(_EmbeddingPlot): |
| 43 | + |
| 44 | + def __init__(self, **kwargs): |
| 45 | + self.figsize = kwargs.get("figsize") |
| 46 | + super().__init__(**kwargs) |
| 47 | + self.colorscale = self._define_colorscale(self.cmap) |
| 48 | + |
| 49 | + def _define_ax(self, axis: Optional[plotly.graph_objects.Figure]): |
| 50 | + """Define the axis of the plot. |
| 51 | +
|
| 52 | + Args: |
| 53 | + axis: Optional axis to create the plot on. |
| 54 | +
|
| 55 | + Returns: |
| 56 | + axis: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. |
| 57 | + """ |
| 58 | + |
| 59 | + if axis is None: |
| 60 | + self.axis = plotly.graph_objects.Figure( |
| 61 | + layout=plotly.graph_objects.Layout(height=100 * self.figsize[0], |
| 62 | + width=100 * self.figsize[1])) |
| 63 | + |
| 64 | + else: |
| 65 | + self.axis = axis |
| 66 | + |
| 67 | + def _define_colorscale(self, cmap: str): |
| 68 | + """Specify the cmap for plotting the latent space. |
| 69 | +
|
| 70 | + Args: |
| 71 | + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). |
| 72 | +
|
| 73 | +
|
| 74 | + Returns: |
| 75 | + colorscale: List of scaled colors to plot the embeddings |
| 76 | + """ |
| 77 | + colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap)) |
| 78 | + |
| 79 | + return colorscale |
| 80 | + |
| 81 | + def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: |
| 82 | + """Plot the embedding in 3d. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. |
| 86 | + """ |
| 87 | + |
| 88 | + idx1, idx2, idx3 = self.idx_order |
| 89 | + data = [ |
| 90 | + plotly.graph_objects.Scatter3d( |
| 91 | + x=self.embedding[:, idx1], |
| 92 | + y=self.embedding[:, idx2], |
| 93 | + z=self.embedding[:, idx3], |
| 94 | + mode="markers", |
| 95 | + marker=dict( |
| 96 | + size=self.markersize, |
| 97 | + opacity=self.alpha, |
| 98 | + color=self.embedding_labels, |
| 99 | + colorscale=self.colorscale, |
| 100 | + ), |
| 101 | + ) |
| 102 | + ] |
| 103 | + col = kwargs.get("col", None) |
| 104 | + row = kwargs.get("row", None) |
| 105 | + |
| 106 | + if col is None or row is None: |
| 107 | + self.axis.add_trace(data[0]) |
| 108 | + else: |
| 109 | + self.axis.add_trace(data[0], row=row, col=col) |
| 110 | + |
| 111 | + self.axis.update_layout( |
| 112 | + template="plotly_white", |
| 113 | + showlegend=False, |
| 114 | + title=self.title, |
| 115 | + ) |
| 116 | + |
| 117 | + return self.axis |
| 118 | + |
| 119 | + |
| 120 | +def plot_embedding_interactive( |
| 121 | + embedding: Union[npt.NDArray, torch.Tensor], |
| 122 | + embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", |
| 123 | + axis: Optional[plotly.graph_objects.Figure] = None, |
| 124 | + markersize: float = 1, |
| 125 | + idx_order: Optional[Tuple[int]] = None, |
| 126 | + alpha: float = 0.4, |
| 127 | + cmap: str = "cool", |
| 128 | + title: str = "Embedding", |
| 129 | + figsize: Tuple[int] = (5, 5), |
| 130 | + dpi: int = 100, |
| 131 | + **kwargs, |
| 132 | +) -> plotly.graph_objects.Figure: |
| 133 | + """Plot embedding in a 3D dimensional space. |
| 134 | +
|
| 135 | + This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of |
| 136 | + dimensions of the embedding (i.e., between 0 and :py:attr:`cebra.CEBRA.output_dimension` -1). |
| 137 | +
|
| 138 | + The function makes use of :py:func:`plotly.graph_objs._scatter.Scatter` and parameters from that function can be provided |
| 139 | + as part of ``kwargs``. |
| 140 | +
|
| 141 | +
|
| 142 | + Args: |
| 143 | + embedding: A matrix containing the feature representation computed with CEBRA. |
| 144 | + embedding_labels: The labels used to map the data to color. It can be: |
| 145 | +
|
| 146 | + * A vector that is the same sample size as the embedding, associating a value to each of the sample, either discrete or continuous. |
| 147 | + * A string, either `time`, then the labels while color the embedding based on temporality, or a string that can be interpreted as a RGB(A) color, then the embedding will be uniformly display with that unique color. |
| 148 | + axis: Optional axis to create the plot on. |
| 149 | + idx_order: A tuple (x, y, z) or (x, y) that maps a dimension in the data to a dimension in the 3D/2D |
| 150 | + embedding. The simplest form is (0, 1, 2) or (0, 1) but one might want to plot either those |
| 151 | + dimensions differently (e.g., (1, 0, 2)) or other dimensions from the feature representation |
| 152 | + (e.g., (2, 4, 5)). |
| 153 | + markersize: The marker size. |
| 154 | + alpha: The marker blending, between 0 (transparent) and 1 (opaque). |
| 155 | + cmap: The Colormap instance or registered colormap name used to map scalar data to colors. It will be ignored if `embedding_labels` is set to a valid RGB(A). |
| 156 | + title: The title on top of the embedding. |
| 157 | + figsize: Figure width and height in inches. |
| 158 | + dpi: Figure resolution. |
| 159 | + kwargs: Optional arguments to customize the plots. See :py:func:`plotly.graph_objs._scatter.Scatter` documentation for more |
| 160 | + details on which arguments to use. |
| 161 | +
|
| 162 | + Returns: |
| 163 | + The plotly figure. |
| 164 | +
|
| 165 | +
|
| 166 | + Example: |
| 167 | +
|
| 168 | + >>> import cebra |
| 169 | + >>> import numpy as np |
| 170 | + >>> X = np.random.uniform(0, 1, (100, 50)) |
| 171 | + >>> y = np.random.uniform(0, 10, (100, 5)) |
| 172 | + >>> cebra_model = cebra.CEBRA(max_iterations=10) |
| 173 | + >>> cebra_model.fit(X, y) |
| 174 | + CEBRA(max_iterations=10) |
| 175 | + >>> embedding = cebra_model.transform(X) |
| 176 | + >>> cebra_time = np.arange(X.shape[0]) |
| 177 | + >>> fig = cebra.integrations.plotly.plot_embedding_interactive(embedding, embedding_labels=cebra_time) |
| 178 | +
|
| 179 | + """ |
| 180 | + return _EmbeddingInteractivePlot( |
| 181 | + embedding=embedding, |
| 182 | + embedding_labels=embedding_labels, |
| 183 | + axis=axis, |
| 184 | + idx_order=idx_order, |
| 185 | + markersize=markersize, |
| 186 | + alpha=alpha, |
| 187 | + cmap=cmap, |
| 188 | + title=title, |
| 189 | + figsize=figsize, |
| 190 | + dpi=dpi, |
| 191 | + ).plot(**kwargs) |
0 commit comments