Skip to content

Commit 33beea2

Browse files
lbluquejanosh
andauthored
Fix pdplotter.show with matplotlib backend (#3493)
* BUG: fix pdplotter.show with matplotlib backend Signed-off-by: lbluque <[email protected]> * add better tests for PDPlotter.get_plot() and show() --------- Signed-off-by: lbluque <[email protected]> Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 9518085 commit 33beea2

File tree

3 files changed

+35
-12
lines changed

3 files changed

+35
-12
lines changed

pymatgen/analysis/phase_diagram.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2276,7 +2276,11 @@ def show(self, *args, **kwargs) -> None:
22762276
*args: Passed to get_plot.
22772277
**kwargs: Passed to get_plot.
22782278
"""
2279-
self.get_plot(*args, **kwargs).show()
2279+
plot = self.get_plot(*args, **kwargs)
2280+
if self.backend == "matplotlib":
2281+
plot.get_figure().show()
2282+
else:
2283+
plot.show()
22802284

22812285
def write_image(self, stream: str | StringIO, image_format: str = "svg", **kwargs) -> None:
22822286
"""

tests/analysis/test_phase_diagram.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
import collections
44
import os
55
import unittest
6+
import unittest.mock
67
from numbers import Number
78

9+
import matplotlib.pyplot as plt
810
import numpy as np
11+
import plotly.graph_objects as go
912
import pytest
1013
from monty.serialization import dumpfn, loadfn
1114
from numpy.testing import assert_allclose
@@ -892,7 +895,9 @@ def test_plot_pd_with_no_unstable(self):
892895
pd_entries = [PDEntry(comp, 0) for comp in ["Li", "Co", "O"]]
893896
pd = PhaseDiagram(pd_entries)
894897
plotter = PDPlotter(pd, backend="plotly", show_unstable=False)
895-
plotter.get_plot()
898+
ax = plotter.get_plot()
899+
assert isinstance(ax, go.Figure)
900+
assert len(ax.data) == 4
896901

897902
def test_pd_plot_data(self):
898903
lines, labels, unstable_entries = self.plotter_ternary_mpl.pd_plot_data
@@ -909,22 +914,36 @@ def test_pd_plot_data(self):
909914
assert len(lines) == 3
910915
assert len(labels) == len(self.pd_binary.stable_entries)
911916

912-
def test_mpl_plots(self):
917+
def test_matplotlib_plots(self):
913918
# Some very basic ("non")-tests. Just to make sure the methods are callable.
914-
self.plotter_binary_mpl.get_plot()
915-
self.plotter_ternary_mpl.get_plot()
916-
self.plotter_quaternary_mpl.get_plot()
919+
for plotter in (self.plotter_binary_mpl, self.plotter_ternary_mpl, self.plotter_quaternary_mpl):
920+
ax = plotter.get_plot()
921+
assert isinstance(ax, plt.Axes)
922+
917923
self.plotter_ternary_mpl.get_contour_pd_plot()
918924
self.plotter_ternary_mpl.get_chempot_range_map_plot([Element("Li"), Element("O")])
919925
self.plotter_ternary_mpl.plot_element_profile(Element("O"), Composition("Li2O"))
920926

927+
# test show()
928+
assert self.plotter_ternary_mpl.show() is None
929+
921930
def test_plotly_plots(self):
922931
# Also very basic tests. Ensures callability and 2D vs 3D properties.
923-
self.plotter_unary_plotly.get_plot()
924-
self.plotter_binary_plotly.get_plot()
925-
self.plotter_ternary_plotly_2d.get_plot()
926-
self.plotter_ternary_plotly_3d.get_plot()
927-
self.plotter_quaternary_plotly.get_plot()
932+
for plotter in (
933+
self.plotter_unary_plotly,
934+
self.plotter_binary_plotly,
935+
self.plotter_ternary_plotly_2d,
936+
self.plotter_ternary_plotly_3d,
937+
self.plotter_quaternary_plotly,
938+
):
939+
fig = plotter.get_plot()
940+
assert isinstance(fig, go.Figure)
941+
942+
# test show()
943+
# suppress default plotly behavior of opening figure in browser by patching plotly.io.show to noop
944+
with unittest.mock.patch("plotly.io.show") as mock_show:
945+
assert self.plotter_ternary_plotly_2d.show() is None
946+
mock_show.assert_called_once()
928947

929948

930949
class TestUtilityFunction(unittest.TestCase):

tests/files/.pytest-split-durations

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@
461461
"tests/analysis/test_phase_diagram.py::TestPDEntry::test_read_csv": 0.004186584032140672,
462462
"tests/analysis/test_phase_diagram.py::TestPDEntry::test_str": 0.00042208394734188914,
463463
"tests/analysis/test_phase_diagram.py::TestPDEntry::test_as_from_dict": 0.00037904095370322466,
464-
"tests/analysis/test_phase_diagram.py::TestPDPlotter::test_mpl_plots": 1.1096638339804485,
464+
"tests/analysis/test_phase_diagram.py::TestPDPlotter::test_matplotlib_plots": 1.1096638339804485,
465465
"tests/analysis/test_phase_diagram.py::TestPDPlotter::test_pd_plot_data": 0.05885787500301376,
466466
"tests/analysis/test_phase_diagram.py::TestPDPlotter::test_plot_pd_with_no_unstable": 0.05304033396532759,
467467
"tests/analysis/test_phase_diagram.py::TestPDPlotter::test_plotly_plots": 0.2711315419874154,

0 commit comments

Comments
 (0)