3
3
import collections
4
4
import os
5
5
import unittest
6
+ import unittest .mock
6
7
from numbers import Number
7
8
9
+ import matplotlib .pyplot as plt
8
10
import numpy as np
11
+ import plotly .graph_objects as go
9
12
import pytest
10
13
from monty .serialization import dumpfn , loadfn
11
14
from numpy .testing import assert_allclose
@@ -892,7 +895,9 @@ def test_plot_pd_with_no_unstable(self):
892
895
pd_entries = [PDEntry (comp , 0 ) for comp in ["Li" , "Co" , "O" ]]
893
896
pd = PhaseDiagram (pd_entries )
894
897
plotter = PDPlotter (pd , backend = "plotly" , show_unstable = False )
895
- plotter .get_plot ()
898
+ ax = plotter .get_plot ()
899
+ assert isinstance (ax , go .Figure )
900
+ assert len (ax .data ) == 4
896
901
897
902
def test_pd_plot_data (self ):
898
903
lines , labels , unstable_entries = self .plotter_ternary_mpl .pd_plot_data
@@ -909,22 +914,36 @@ def test_pd_plot_data(self):
909
914
assert len (lines ) == 3
910
915
assert len (labels ) == len (self .pd_binary .stable_entries )
911
916
912
- def test_mpl_plots (self ):
917
+ def test_matplotlib_plots (self ):
913
918
# Some very basic ("non")-tests. Just to make sure the methods are callable.
914
- self .plotter_binary_mpl .get_plot ()
915
- self .plotter_ternary_mpl .get_plot ()
916
- self .plotter_quaternary_mpl .get_plot ()
919
+ for plotter in (self .plotter_binary_mpl , self .plotter_ternary_mpl , self .plotter_quaternary_mpl ):
920
+ ax = plotter .get_plot ()
921
+ assert isinstance (ax , plt .Axes )
922
+
917
923
self .plotter_ternary_mpl .get_contour_pd_plot ()
918
924
self .plotter_ternary_mpl .get_chempot_range_map_plot ([Element ("Li" ), Element ("O" )])
919
925
self .plotter_ternary_mpl .plot_element_profile (Element ("O" ), Composition ("Li2O" ))
920
926
927
+ # test show()
928
+ assert self .plotter_ternary_mpl .show () is None
929
+
921
930
def test_plotly_plots (self ):
922
931
# Also very basic tests. Ensures callability and 2D vs 3D properties.
923
- self .plotter_unary_plotly .get_plot ()
924
- self .plotter_binary_plotly .get_plot ()
925
- self .plotter_ternary_plotly_2d .get_plot ()
926
- self .plotter_ternary_plotly_3d .get_plot ()
927
- self .plotter_quaternary_plotly .get_plot ()
932
+ for plotter in (
933
+ self .plotter_unary_plotly ,
934
+ self .plotter_binary_plotly ,
935
+ self .plotter_ternary_plotly_2d ,
936
+ self .plotter_ternary_plotly_3d ,
937
+ self .plotter_quaternary_plotly ,
938
+ ):
939
+ fig = plotter .get_plot ()
940
+ assert isinstance (fig , go .Figure )
941
+
942
+ # test show()
943
+ # suppress default plotly behavior of opening figure in browser by patching plotly.io.show to noop
944
+ with unittest .mock .patch ("plotly.io.show" ) as mock_show :
945
+ assert self .plotter_ternary_plotly_2d .show () is None
946
+ mock_show .assert_called_once ()
928
947
929
948
930
949
class TestUtilityFunction (unittest .TestCase ):
0 commit comments