Skip to content

Commit ad7a3d2

Browse files
committed
add 'LearnerND.plot_3D' and add an example to the docs
1 parent b563e09 commit ad7a3d2

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

adaptive/learner/learnerND.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .base_learner import BaseLearner
1212

13-
from ..notebook_integration import ensure_holoviews
13+
from ..notebook_integration import ensure_holoviews, ensure_plotly
1414
from .triangulation import (Triangulation, point_in_simplex,
1515
circumsphere, simplex_volume_in_embedding)
1616
from ..utils import restore, cache_latest
@@ -585,6 +585,69 @@ def plot_slice(self, cut_mapping, n=None):
585585
else:
586586
raise ValueError("Only 1 or 2-dimensional plots can be generated.")
587587

588+
def plot_3D(self, with_triangulation=False):
589+
"""Plot the learner's data in 3D using plotly.
590+
591+
Parameters
592+
----------
593+
with_triangulation : bool, default: False
594+
Add the verticices to the plot.
595+
596+
Returns
597+
-------
598+
plot : plotly.offline.iplot object
599+
The 3D plot of ``learner.data``.
600+
"""
601+
plotly = ensure_plotly()
602+
603+
plots = []
604+
605+
vertices = self.tri.vertices
606+
if with_triangulation:
607+
Xe, Ye, Ze = [], [], []
608+
for simplex in self.tri.simplices:
609+
for s in itertools.combinations(simplex, 2):
610+
Xe += [vertices[i][0] for i in s] + [None]
611+
Ye += [vertices[i][1] for i in s] + [None]
612+
Ze += [vertices[i][2] for i in s] + [None]
613+
614+
plots.append(plotly.graph_objs.Scatter3d(
615+
x=Xe, y=Ye, z=Ze, mode='lines',
616+
line=dict(color='rgb(125,125,125)', width=1),
617+
hoverinfo='none'
618+
))
619+
620+
Xn, Yn, Zn = zip(*vertices)
621+
colors = [self.data[p] for p in self.tri.vertices]
622+
marker = dict(symbol='circle', size=3, color=colors,
623+
colorscale='Viridis',
624+
line=dict(color='rgb(50,50,50)', width=0.5))
625+
626+
plots.append(plotly.graph_objs.Scatter3d(
627+
x=Xn, y=Yn, z=Zn, mode='markers',
628+
name='actors', marker=marker,
629+
hoverinfo='text'
630+
))
631+
632+
axis = dict(
633+
showbackground=False,
634+
showline=False,
635+
zeroline=False,
636+
showgrid=False,
637+
showticklabels=False,
638+
title='',
639+
)
640+
641+
layout = plotly.graph_objs.Layout(
642+
showlegend=False,
643+
scene=dict(xaxis=axis, yaxis=axis, zaxis=axis),
644+
margin=dict(t=100),
645+
hovermode='closest')
646+
647+
fig = plotly.graph_objs.Figure(data=plots, layout=layout)
648+
649+
return plotly.offline.iplot(fig)
650+
588651
def _get_data(self):
589652
return self.data
590653

0 commit comments

Comments
 (0)