Skip to content

Commit 45e2229

Browse files
stesMMathisLabnastya236
authored
Pass showlegend/template kwargs to plotly (#122)
* Pass showlegend/template kwargs to plotly * Update plotly.py - minor fix * fix dropped " * Fix showlegend for plotly * Fix showlegend for plotly * Add handling of continuous labels * Add tests for plotly integration * fix typo --------- Co-authored-by: Mackenzie Mathis <[email protected]> Co-authored-by: Anastasiia Filippova <[email protected]>
1 parent 6a6a07d commit 45e2229

File tree

2 files changed

+89
-23
lines changed

2 files changed

+89
-23
lines changed

cebra/integrations/plotly.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -94,33 +94,55 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure:
9494
Returns:
9595
The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
9696
"""
97-
98-
idx1, idx2, idx3 = self.idx_order
99-
data = [
100-
plotly.graph_objects.Scatter3d(
101-
x=self.embedding[:, idx1],
102-
y=self.embedding[:, idx2],
103-
z=self.embedding[:, idx3],
104-
mode="markers",
105-
marker=dict(
106-
size=self.markersize,
107-
opacity=self.alpha,
108-
color=self.embedding_labels,
109-
colorscale=self.colorscale,
110-
),
111-
)
112-
]
97+
showlegend = kwargs.get("showlegend", False)
98+
discrete = kwargs.get("discrete", False)
11399
col = kwargs.get("col", None)
114100
row = kwargs.get("row", None)
101+
template = kwargs.get("template", "plotly_white")
102+
data = []
115103

116-
if col is None or row is None:
117-
self.axis.add_trace(data[0])
104+
if not discrete and showlegend:
105+
raise ValueError("Cannot show legend with continuous labels.")
106+
107+
idx1, idx2, idx3 = self.idx_order
108+
109+
if discrete:
110+
unique_labels = np.unique(self.embedding_labels)
118111
else:
119-
self.axis.add_trace(data[0], row=row, col=col)
112+
unique_labels = [self.embedding_labels]
113+
114+
for label in unique_labels:
115+
if discrete:
116+
filtered_idx = [
117+
i for i, x in enumerate(self.embedding_labels) if x == label
118+
]
119+
else:
120+
filtered_idx = np.arange(self.embedding.shape[0])
121+
data.append(
122+
plotly.graph_objects.Scatter3d(x=self.embedding[filtered_idx,
123+
idx1],
124+
y=self.embedding[filtered_idx,
125+
idx2],
126+
z=self.embedding[filtered_idx,
127+
idx3],
128+
mode="markers",
129+
marker=dict(
130+
size=self.markersize,
131+
opacity=self.alpha,
132+
color=label,
133+
colorscale=self.colorscale,
134+
),
135+
name=str(label)))
136+
137+
for trace in data:
138+
if col is None or row is None:
139+
self.axis.add_trace(trace)
140+
else:
141+
self.axis.add_trace(trace, row=row, col=col)
120142

121143
self.axis.update_layout(
122-
template="plotly_white",
123-
showlegend=False,
144+
template=template,
145+
showlegend=showlegend,
124146
title=self.title,
125147
)
126148

@@ -166,8 +188,17 @@ def plot_embedding_interactive(
166188
title: The title on top of the embedding.
167189
figsize: Figure width and height in inches.
168190
dpi: Figure resolution.
169-
kwargs: Optional arguments to customize the plots. See :py:class:`plotly.graph_objects.Scatter` documentation for more
170-
details on which arguments to use.
191+
kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments:
192+
-- showlegend: Whether to show the legend or not.
193+
-- discrete: Whether the labels are discrete or not.
194+
-- col: The column of the subplot to plot the embedding on.
195+
-- row: The row of the subplot to plot the embedding on.
196+
-- template: The template to use for the plot.
197+
198+
Note: showlegend can be True only if discrete is True.
199+
200+
See :py:class:`plotly.graph_objects.Scatter` documentation for more
201+
details on which arguments to use.
171202
172203
Returns:
173204
The plotly figure.

tests/test_plotly.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,38 @@ def test_plot_embedding(output_dimension, idx_order):
8484

8585
fig_subplots.data = []
8686
fig_subplots.layout = {}
87+
88+
89+
def test_discrete_with_legend():
90+
embedding = np.random.uniform(0, 1, (1000, 3))
91+
labels = np.random.randint(0, 10, (1000,))
92+
93+
fig = cebra_plotly.plot_embedding_interactive(embedding,
94+
labels,
95+
discrete=True,
96+
showlegend=True)
97+
98+
assert len(fig._data_objs) == np.unique(labels).shape[0]
99+
assert isinstance(fig, go.Figure)
100+
101+
102+
def test_continuous_no_legend():
103+
embedding = np.random.uniform(0, 1, (1000, 3))
104+
labels = np.random.uniform(0, 1, (1000,))
105+
106+
fig = cebra_plotly.plot_embedding_interactive(embedding, labels)
107+
108+
assert len(fig._data_objs) == 1
109+
110+
assert isinstance(fig, go.Figure)
111+
112+
113+
def test_continuous_with_legend_raises_error():
114+
embedding = np.random.uniform(0, 1, (1000, 3))
115+
labels = np.random.uniform(0, 1, (1000,))
116+
117+
with pytest.raises(ValueError):
118+
cebra_plotly.plot_embedding_interactive(embedding,
119+
labels,
120+
discrete=False,
121+
showlegend=True)

0 commit comments

Comments
 (0)