@@ -94,33 +94,55 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure:
9494 Returns:
9595 The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
9696 """
97-
98- idx1 , idx2 , idx3 = self .idx_order
99- data = [
100- plotly .graph_objects .Scatter3d (
101- x = self .embedding [:, idx1 ],
102- y = self .embedding [:, idx2 ],
103- z = self .embedding [:, idx3 ],
104- mode = "markers" ,
105- marker = dict (
106- size = self .markersize ,
107- opacity = self .alpha ,
108- color = self .embedding_labels ,
109- colorscale = self .colorscale ,
110- ),
111- )
112- ]
97+ showlegend = kwargs .get ("showlegend" , False )
98+ discrete = kwargs .get ("discrete" , False )
11399 col = kwargs .get ("col" , None )
114100 row = kwargs .get ("row" , None )
101+ template = kwargs .get ("template" , "plotly_white" )
102+ data = []
115103
116- if col is None or row is None :
117- self .axis .add_trace (data [0 ])
104+ if not discrete and showlegend :
105+ raise ValueError ("Cannot show legend with continuous labels." )
106+
107+ idx1 , idx2 , idx3 = self .idx_order
108+
109+ if discrete :
110+ unique_labels = np .unique (self .embedding_labels )
118111 else :
119- self .axis .add_trace (data [0 ], row = row , col = col )
112+ unique_labels = [self .embedding_labels ]
113+
114+ for label in unique_labels :
115+ if discrete :
116+ filtered_idx = [
117+ i for i , x in enumerate (self .embedding_labels ) if x == label
118+ ]
119+ else :
120+ filtered_idx = np .arange (self .embedding .shape [0 ])
121+ data .append (
122+ plotly .graph_objects .Scatter3d (x = self .embedding [filtered_idx ,
123+ idx1 ],
124+ y = self .embedding [filtered_idx ,
125+ idx2 ],
126+ z = self .embedding [filtered_idx ,
127+ idx3 ],
128+ mode = "markers" ,
129+ marker = dict (
130+ size = self .markersize ,
131+ opacity = self .alpha ,
132+ color = label ,
133+ colorscale = self .colorscale ,
134+ ),
135+ name = str (label )))
136+
137+ for trace in data :
138+ if col is None or row is None :
139+ self .axis .add_trace (trace )
140+ else :
141+ self .axis .add_trace (trace , row = row , col = col )
120142
121143 self .axis .update_layout (
122- template = "plotly_white" ,
123- showlegend = False ,
144+ template = template ,
145+ showlegend = showlegend ,
124146 title = self .title ,
125147 )
126148
@@ -166,8 +188,17 @@ def plot_embedding_interactive(
166188 title: The title on top of the embedding.
167189 figsize: Figure width and height in inches.
168190 dpi: Figure resolution.
169- kwargs: Optional arguments to customize the plots. See :py:class:`plotly.graph_objects.Scatter` documentation for more
170- details on which arguments to use.
191+ kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments:
192+ -- showlegend: Whether to show the legend or not.
193+ -- discrete: Whether the labels are discrete or not.
194+ -- col: The column of the subplot to plot the embedding on.
195+ -- row: The row of the subplot to plot the embedding on.
196+ -- template: The template to use for the plot.
197+
198+ Note: showlegend can be True only if discrete is True.
199+
200+ See :py:class:`plotly.graph_objects.Scatter` documentation for more
201+ details on which arguments to use.
171202
172203 Returns:
173204 The plotly figure.
0 commit comments