|
10 | 10 |
|
11 | 11 | from .base_learner import BaseLearner
|
12 | 12 |
|
13 |
| -from ..notebook_integration import ensure_holoviews |
| 13 | +from ..notebook_integration import ensure_holoviews, ensure_plotly |
14 | 14 | from .triangulation import (Triangulation, point_in_simplex,
|
15 | 15 | circumsphere, simplex_volume_in_embedding)
|
16 | 16 | from ..utils import restore, cache_latest
|
@@ -585,6 +585,69 @@ def plot_slice(self, cut_mapping, n=None):
|
585 | 585 | else:
|
586 | 586 | raise ValueError("Only 1 or 2-dimensional plots can be generated.")
|
587 | 587 |
|
| 588 | + def plot_3D(self, with_triangulation=False): |
| 589 | + """Plot the learner's data in 3D using plotly. |
| 590 | +
|
| 591 | + Parameters |
| 592 | + ---------- |
| 593 | + with_triangulation : bool, default: False |
| 594 | + Add the verticices to the plot. |
| 595 | +
|
| 596 | + Returns |
| 597 | + ------- |
| 598 | + plot : plotly.offline.iplot object |
| 599 | + The 3D plot of ``learner.data``. |
| 600 | + """ |
| 601 | + plotly = ensure_plotly() |
| 602 | + |
| 603 | + plots = [] |
| 604 | + |
| 605 | + vertices = self.tri.vertices |
| 606 | + if with_triangulation: |
| 607 | + Xe, Ye, Ze = [], [], [] |
| 608 | + for simplex in self.tri.simplices: |
| 609 | + for s in itertools.combinations(simplex, 2): |
| 610 | + Xe += [vertices[i][0] for i in s] + [None] |
| 611 | + Ye += [vertices[i][1] for i in s] + [None] |
| 612 | + Ze += [vertices[i][2] for i in s] + [None] |
| 613 | + |
| 614 | + plots.append(plotly.graph_objs.Scatter3d( |
| 615 | + x=Xe, y=Ye, z=Ze, mode='lines', |
| 616 | + line=dict(color='rgb(125,125,125)', width=1), |
| 617 | + hoverinfo='none' |
| 618 | + )) |
| 619 | + |
| 620 | + Xn, Yn, Zn = zip(*vertices) |
| 621 | + colors = [self.data[p] for p in self.tri.vertices] |
| 622 | + marker = dict(symbol='circle', size=3, color=colors, |
| 623 | + colorscale='Viridis', |
| 624 | + line=dict(color='rgb(50,50,50)', width=0.5)) |
| 625 | + |
| 626 | + plots.append(plotly.graph_objs.Scatter3d( |
| 627 | + x=Xn, y=Yn, z=Zn, mode='markers', |
| 628 | + name='actors', marker=marker, |
| 629 | + hoverinfo='text' |
| 630 | + )) |
| 631 | + |
| 632 | + axis = dict( |
| 633 | + showbackground=False, |
| 634 | + showline=False, |
| 635 | + zeroline=False, |
| 636 | + showgrid=False, |
| 637 | + showticklabels=False, |
| 638 | + title='', |
| 639 | + ) |
| 640 | + |
| 641 | + layout = plotly.graph_objs.Layout( |
| 642 | + showlegend=False, |
| 643 | + scene=dict(xaxis=axis, yaxis=axis, zaxis=axis), |
| 644 | + margin=dict(t=100), |
| 645 | + hovermode='closest') |
| 646 | + |
| 647 | + fig = plotly.graph_objs.Figure(data=plots, layout=layout) |
| 648 | + |
| 649 | + return plotly.offline.iplot(fig) |
| 650 | + |
588 | 651 | def _get_data(self):
|
589 | 652 | return self.data
|
590 | 653 |
|
|
0 commit comments