Skip to content

Commit 1d25d7f

Browse files
feat: Add support for plotly backend
1 parent 4b0af2d commit 1d25d7f

File tree

11 files changed

+295
-21
lines changed

11 files changed

+295
-21
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from ansys.tools.visualization_interface.backends.plotly.plotly_interface import PlotlyBackend
2+
from ansys.tools.visualization_interface.types import MeshObjectPlot
3+
from ansys.tools.visualization_interface import Plotter
4+
import pyvista as pv
5+
from plotly.graph_objects import Mesh3d
6+
7+
# Create a plotter with the Plotly backend
8+
pl = Plotter(backend=PlotlyBackend())
9+
10+
# Create a PyVista mesh
11+
mesh = pv.Sphere()
12+
13+
# Plot the mesh
14+
pl.plot(mesh)
15+
16+
# Display the plotter
17+
pl.show()
18+
19+
# Now create a custom object
20+
class CustomObject:
21+
def __init__(self):
22+
self.name = "CustomObject"
23+
self.mesh = pv.Cube(center=(1, 1, 0))
24+
25+
def get_mesh(self):
26+
return self.mesh
27+
28+
def name(self):
29+
return self.name
30+
31+
32+
33+
# Create a custom object
34+
custom_cube = CustomObject()
35+
custom_cube.name = "CustomCube"
36+
37+
# Create a MeshObjectPlot instance
38+
mesh_object_cube = MeshObjectPlot(custom_cube, custom_cube.get_mesh())
39+
40+
# Plot the custom mesh object
41+
pl.plot(mesh_object_cube)
42+
43+
# Since Plotly is a web-based visualization, we can show the plot again to include the new object
44+
pl.show()
45+
46+
# Add a Plotly Mesh3d object directly
47+
custom_mesh3d = Mesh3d(
48+
x=[0, 1, 2],
49+
y=[0, 1, 0],
50+
z=[0, 0, 1],
51+
i=[0],
52+
j=[1],
53+
k=[2],
54+
color='lightblue',
55+
opacity=0.50
56+
)
57+
pl.plot(custom_mesh3d)
58+
pl.show()
59+
60+
# Show other plotly objects like Scatter3d
61+
from plotly.graph_objects import Scatter3d
62+
63+
scatter = Scatter3d(
64+
x=[0, 1, 2],
65+
y=[0, 1, 0],
66+
z=[0, 0, 1],
67+
mode='markers',
68+
marker=dict(size=5, color='red')
69+
)
70+
pl.plot(scatter)
71+
pl.show()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pyvistaqt = [
3939

4040
plotly = [
4141
"plotly >= 5.15.0,<6",
42+
"kaleido >= 1.1.0,<2",
4243
]
4344
tests = [
4445
"pytest==8.4.2",

src/ansys/tools/visualization_interface/backends/plotly/plotly_interface.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Union, Iterable, Any
77

88

9-
class PlotlyInterface(BaseBackend):
9+
class PlotlyBackend(BaseBackend):
1010
"""Plotly interface for visualization."""
1111

1212
def __init__(self, **kwargs):
@@ -17,16 +17,20 @@ def _pv_to_mesh3d(self, pv_mesh: PolyData) -> go.Mesh3d:
1717
points = pv_mesh.points
1818
x, y, z = points[:, 0], points[:, 1], points[:, 2]
1919

20-
faces = pv_mesh.faces.reshape((-1, 4)) # First number in each row is the number of points in the face (3 for triangles)
20+
# Convert mesh to triangular mesh if needed, since Plotly only supports triangular faces
21+
triangulated_mesh = pv_mesh.triangulate()
22+
23+
# Extract triangular faces
24+
faces = triangulated_mesh.faces.reshape((-1, 4)) # Now we know all faces are triangular (3 vertices + count)
2125
i, j, k = faces[:, 1], faces[:, 2], faces[:, 3]
2226

2327
return go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k)
2428
@property
2529
def layout(self) -> Any:
2630
"""Get the current layout of the Plotly figure."""
2731
return self._fig.layout
28-
29-
@setters.layout
32+
33+
@layout.setter
3034
def layout(self, new_layout: Any):
3135
"""Set a new layout for the Plotly figure."""
3236
self._fig.update_layout(new_layout)
@@ -54,6 +58,22 @@ def plot(self, plottable_object: Union[PolyData, MeshObjectPlot, go.Mesh3d], **p
5458
except Exception:
5559
raise TypeError("Unsupported plottable_object type for PlotlyInterface.")
5660

57-
def show(self):
61+
def show(self,
62+
plottable_object=None,
63+
screenshot: str = None,
64+
name_filter=None,
65+
**kwargs):
5866
"""Render the Plotly scene."""
59-
self._fig.show()
67+
if plottable_object is not None:
68+
self.plot(plottable_object)
69+
70+
# Only show in browser if no screenshot is being taken
71+
if not screenshot:
72+
self._fig.show(**kwargs)
73+
74+
if screenshot:
75+
screenshot_str = str(screenshot)
76+
if screenshot_str.endswith('.html'):
77+
self._fig.write_html(screenshot_str)
78+
else:
79+
self._fig.write_image(screenshot_str)

src/ansys/tools/visualization_interface/plotter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ def backend(self):
4949
"""Return the base plotter object."""
5050
return self._backend
5151

52+
def plot_iter(self, plotting_list: List, **plotting_options):
53+
"""Plots multiple objects using the specified backend.
54+
55+
Parameters
56+
----------
57+
plotting_list : List
58+
List of objects to plot.
59+
plotting_options : dict
60+
Additional plotting options.
61+
"""
62+
self._backend.plot_iter(plotting_list=plotting_list, **plotting_options)
63+
5264
def plot(self, plottable_object: Any, **plotting_options):
5365
"""Plots an object using the specified backend.
5466

src/ansys/tools/visualization_interface/types/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222
"""Provides custom types."""
23+
from .mesh_object_plot import MeshObjectPlot

tests/conftest.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,51 @@
2323
import os
2424

2525
import pytest
26+
from PIL import Image, ImageChops
27+
from pathlib import Path
2628

2729
os.environ.setdefault("PYANSYS_VISUALIZER_TESTMODE", "true")
2830

29-
@pytest.fixture(autouse=True)
30-
def wrapped_verify_image_cache(verify_image_cache):
31-
"""Wraps the verify_image_cache fixture to ensure that the image cache is verified.
32-
33-
Parameters
34-
----------
35-
verify_image_cache : fixture
36-
Fixture to wrap.
37-
38-
Returns
39-
-------
40-
fixture
41-
Wrapped fixture.
42-
"""
43-
return verify_image_cache
31+
@pytest.fixture
32+
def image_compare():
33+
"""Fixture to compare images."""
34+
def _compare_images(generated_image_path):
35+
"""Compare two images and optionally save the difference image.
36+
37+
Parameters
38+
----------
39+
generated_image_path : str
40+
Path to the generated image.
41+
baseline_image_path : str
42+
Path to the baseline image.
43+
diff_image_path : str, optional
44+
Path to save the difference image if images do not match.
45+
46+
Returns
47+
-------
48+
bool
49+
True if images match, False otherwise.
50+
"""
51+
# Get the name of the image file using Pathlib
52+
image_name = Path(generated_image_path).name
53+
54+
# Define the baseline image path
55+
baseline_image_path = Path(__file__).parent / "_image_cache" / image_name
56+
57+
img1 = Image.open(generated_image_path).convert("RGB")
58+
try:
59+
img2 = Image.open(baseline_image_path).convert("RGB")
60+
except FileNotFoundError:
61+
# copy generated image to baseline location if baseline does not exist
62+
img1.save(baseline_image_path)
63+
img2 = Image.open(baseline_image_path).convert("RGB")
64+
65+
# Compute the difference between the two images
66+
diff = ImageChops.difference(img1, img2)
67+
68+
if diff.getbbox() is None:
69+
return True
70+
else:
71+
return False
72+
73+
return _compare_images

tests/test_generic_plotter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ def __init__(self, name) -> None:
4242
self.name = name
4343

4444

45+
@pytest.fixture(autouse=True)
46+
def wrapped_verify_image_cache(verify_image_cache):
47+
"""Wraps the verify_image_cache fixture to ensure that the image cache is verified.
48+
49+
Parameters
50+
----------
51+
verify_image_cache : fixture
52+
Fixture to wrap.
53+
54+
Returns
55+
-------
56+
fixture
57+
Wrapped fixture.
58+
"""
59+
return verify_image_cache
60+
61+
4562
def test_plotter_add_pd():
4663
"""Adds polydata to the plotter."""
4764
pl = Plotter()

0 commit comments

Comments
 (0)