Skip to content

Commit 10d5cc5

Browse files
authored
Prefer pymatviz interactive plotly version of periodic table heatmap if available (#3180)
* try from pymatviz import ptable_heatmap_plotly in periodic_table_heatmap() * fix tests * snake_case * add keyword pymatviz: bool = True to disable periodic_table_heatmap deprecation warning
1 parent 3085b27 commit 10d5cc5

File tree

3 files changed

+95
-35
lines changed

3 files changed

+95
-35
lines changed

pymatgen/symmetry/kpath.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
18241824

18251825
g = np.dot(W.T, W) # just using change of basis matrix rather than
18261826
# Lattice.get_cartesian_coordinates for conciseness
1827-
ginv = np.linalg.inv(g)
1827+
g_inv = np.linalg.inv(g)
18281828
D = np.linalg.det(W)
18291829

18301830
primary_orientation = secondary_orientation = tertiary_orientation = None
@@ -1871,7 +1871,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
18711871
face_center_found = False
18721872
for point in IRBZ_points:
18731873
if point[0] in face_center_inds:
1874-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1874+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
18751875
if not np.allclose(cross, 0, atol=atol):
18761876
rot_boundaries = [cross, -1 * np.dot(op, cross)]
18771877
face_center_found = True
@@ -1880,7 +1880,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
18801880
if not face_center_found:
18811881
print("face center not found")
18821882
for point in IRBZ_points:
1883-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1883+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
18841884
if not np.allclose(cross, 0, atol=atol):
18851885
rot_boundaries = [cross, -1 * np.dot(op, cross)]
18861886
used_axes.append(ax)
@@ -1896,7 +1896,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
18961896
face_center_found = False
18971897
for point in IRBZ_points:
18981898
if point[0] in face_center_inds:
1899-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1899+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19001900
if not np.allclose(cross, 0, atol=atol):
19011901
rot_boundaries = [cross, np.dot(op, cross)]
19021902
face_center_found = True
@@ -1905,7 +1905,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19051905
if not face_center_found:
19061906
print("face center not found")
19071907
for point in IRBZ_points:
1908-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1908+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19091909
if not np.allclose(cross, 0, atol=atol):
19101910
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19111911
used_axes.append(ax)
@@ -1921,7 +1921,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19211921
face_center_found = False
19221922
for point in IRBZ_points:
19231923
if point[0] in face_center_inds:
1924-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1924+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19251925
if not np.allclose(cross, 0, atol=atol):
19261926
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19271927
face_center_found = True
@@ -1930,7 +1930,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19301930
if not face_center_found:
19311931
print("face center not found")
19321932
for point in IRBZ_points:
1933-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1933+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19341934
if not np.allclose(cross, 0, atol=atol):
19351935
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19361936
used_axes.append(ax)
@@ -1946,7 +1946,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19461946
face_center_found = False
19471947
for point in IRBZ_points:
19481948
if point[0] in face_center_inds:
1949-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1949+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19501950
if not np.allclose(cross, 0, atol=atol):
19511951
rot_boundaries = [
19521952
cross,
@@ -1958,7 +1958,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19581958
if not face_center_found:
19591959
print("face center not found")
19601960
for point in IRBZ_points:
1961-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1961+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19621962
if not np.allclose(cross, 0, atol=atol):
19631963
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19641964
used_axes.append(ax)
@@ -1974,7 +1974,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19741974
face_center_found = False
19751975
for point in IRBZ_points:
19761976
if point[0] in face_center_inds:
1977-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1977+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19781978
if not np.allclose(cross, 0, atol=atol):
19791979
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19801980
face_center_found = True
@@ -1983,7 +1983,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19831983
if not face_center_found:
19841984
print("face center not found")
19851985
for point in IRBZ_points:
1986-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
1986+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
19871987
if not np.allclose(cross, 0, atol=atol):
19881988
rot_boundaries = [cross, -1 * np.dot(op, cross)]
19891989
used_axes.append(ax)
@@ -1999,7 +1999,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
19991999
face_center_found = False
20002000
for point in IRBZ_points:
20012001
if point[0] in face_center_inds:
2002-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
2002+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
20032003
if not np.allclose(cross, 0, atol=atol):
20042004
rot_boundaries = [cross, -1 * np.dot(op, cross)]
20052005
face_center_found = True
@@ -2008,7 +2008,7 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol):
20082008
if not face_center_found:
20092009
print("face center not found")
20102010
for point in IRBZ_points:
2011-
cross = D * np.dot(ginv, np.cross(ax, point[1]))
2011+
cross = D * np.dot(g_inv, np.cross(ax, point[1]))
20122012
if not np.allclose(cross, 0, atol=atol):
20132013
rot_boundaries = [cross, -1 * np.dot(op, cross)]
20142014
used_axes.append(ax)

pymatgen/util/plotting.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _decide_fontcolor(rgba: tuple) -> Literal["black", "white"]:
180180

181181

182182
def periodic_table_heatmap(
183-
elemental_data,
183+
elemental_data=None,
184184
cbar_label="",
185185
cbar_label_size=14,
186186
show_plot=False,
@@ -193,39 +193,90 @@ def periodic_table_heatmap(
193193
symbol_fontsize=14,
194194
max_row=9,
195195
readable_fontcolor=False,
196+
pymatviz: bool = True,
197+
**kwargs,
196198
):
197199
"""
198200
A static method that generates a heat map overlaid on a periodic table.
199201
200202
Args:
201-
elemental_data (dict): A dictionary with the element as a key and a
203+
elemental_data (dict): A dictionary with the element as a key and a
202204
value assigned to it, e.g. surface energy and frequency, etc.
203205
Elements missing in the elemental_data will be grey by default
204206
in the final table elemental_data={"Fe": 4.2, "O": 5.0}.
205-
cbar_label (str): Label of the color bar. Default is "".
206-
cbar_label_size (float): Font size for the color bar label. Default is 14.
207-
cmap_range (tuple): Minimum and maximum value of the color map scale.
207+
cbar_label (str): Label of the color bar. Default is "".
208+
cbar_label_size (float): Font size for the color bar label. Default is 14.
209+
cmap_range (tuple): Minimum and maximum value of the color map scale.
208210
If None, the color map will automatically scale to the range of the
209211
data.
210-
show_plot (bool): Whether to show the heatmap. Default is False.
211-
value_format (str): Formatting string to show values. If None, no value
212+
show_plot (bool): Whether to show the heatmap. Default is False.
213+
value_format (str): Formatting string to show values. If None, no value
212214
is shown. Example: "%.4f" shows float to four decimals.
213-
value_fontsize (float): Font size for values. Default is 10.
214-
symbol_fontsize (float): Font size for element symbols. Default is 14.
215-
cmap (str): Color scheme of the heatmap. Default is 'YlOrRd'.
215+
value_fontsize (float): Font size for values. Default is 10.
216+
symbol_fontsize (float): Font size for element symbols. Default is 14.
217+
cmap (str): Color scheme of the heatmap. Default is 'YlOrRd'.
216218
Refer to the matplotlib documentation for other options.
217-
blank_color (str): Color assigned for the missing elements in
219+
blank_color (str): Color assigned for the missing elements in
218220
elemental_data. Default is "grey".
219-
edge_color (str): Color assigned for the edge of elements in the
221+
edge_color (str): Color assigned for the edge of elements in the
220222
periodic table. Default is "white".
221-
max_row (int): Maximum number of rows of the periodic table to be
223+
max_row (int): Maximum number of rows of the periodic table to be
222224
shown. Default is 9, which means the periodic table heat map covers
223225
the standard 7 rows of the periodic table + 2 rows for the lanthanides
224226
and actinides. Use a value of max_row = 7 to exclude the lanthanides and
225227
actinides.
226-
readable_fontcolor (bool): Whether to use readable font color depending
228+
readable_fontcolor (bool): Whether to use readable font color depending
227229
on background color. Default is False.
230+
pymatviz (bool): Whether to use pymatviz to generate the heatmap. Defaults to True.
231+
See https://github.com/janosh/pymatviz.
232+
kwargs: Passed to pymatviz.ptable_heatmap_plotly
228233
"""
234+
if pymatviz:
235+
try:
236+
from pymatviz import ptable_heatmap_plotly
237+
238+
if elemental_data:
239+
kwargs.setdefault("elem_values", elemental_data)
240+
print('elemental_data is deprecated, use elem_values={"Fe": 4.2, "O": 5.0} instead')
241+
if cbar_label:
242+
kwargs.setdefault("color_bar", {}).setdefault("title", cbar_label)
243+
print('cbar_label is deprecated, use color_bar={"title": cbar_label} instead')
244+
if cbar_label_size != 14:
245+
kwargs.setdefault("color_bar", {}).setdefault("titlefont", {}).setdefault("size", cbar_label_size)
246+
print('cbar_label_size is deprecated, use color_bar={"titlefont": {"size": cbar_label_size}} instead')
247+
if cmap:
248+
kwargs.setdefault("colorscale", cmap)
249+
print("cmap is deprecated, use colorscale=cmap instead")
250+
if cmap_range:
251+
kwargs.setdefault("cscale_range", cmap_range)
252+
print("cmap_range is deprecated, use cscale_range instead")
253+
if value_format:
254+
kwargs.setdefault("precision", value_format)
255+
print("value_format is deprecated, use precision instead")
256+
if blank_color != "grey":
257+
print("blank_color is deprecated")
258+
if edge_color != "white":
259+
print("edge_color is deprecated")
260+
if symbol_fontsize != 14:
261+
print("symbol_fontsize is deprecated, use font_size instead")
262+
kwargs.setdefault("font_size", symbol_fontsize)
263+
if value_fontsize != 10:
264+
print("value_fontsize is deprecated, use font_size instead")
265+
kwargs.setdefault("font_size", value_fontsize)
266+
if max_row != 9:
267+
print("max_row is deprecated, use max_row instead")
268+
if readable_fontcolor:
269+
print("readable_fontcolor is deprecated, use font_colors instead, e.g. ('black', 'white')")
270+
271+
return ptable_heatmap_plotly(**kwargs)
272+
except ImportError:
273+
print(
274+
"You're using a deprecated version of periodic_table_heatmap(). Consider `pip install pymatviz` which "
275+
"offers an interactive plotly periodic table heatmap. You can keep calling this same function from "
276+
"pymatgen. Some of the arguments have changed which you'll be warned about. "
277+
"To disable this warning, pass pymatviz=False."
278+
)
279+
229280
# Convert primitive_elemental data in the form of numpy array for plotting.
230281
if cmap_range is not None:
231282
max_val = cmap_range[1]
@@ -286,7 +337,7 @@ def periodic_table_heatmap(
286337
ax.axis("off")
287338
ax.invert_yaxis()
288339

289-
# Set the scalermap for fontcolor
340+
# Set the scalarmap for fontcolor
290341
norm = colors.Normalize(vmin=min_val, vmax=max_val)
291342
scalar_cmap = cm.ScalarMappable(norm=norm, cmap=cmap)
292343

@@ -335,20 +386,20 @@ def format_formula(formula):
335386
"""
336387
formatted_formula = ""
337388
number_format = ""
338-
for i, s in enumerate(formula):
339-
if s.isdigit():
389+
for idx, char in enumerate(formula):
390+
if char.isdigit():
340391
if not number_format:
341392
number_format = "_{"
342-
number_format += s
343-
if i == len(formula) - 1:
393+
number_format += char
394+
if idx == len(formula) - 1:
344395
number_format += "}"
345396
formatted_formula += number_format
346397
else:
347398
if number_format:
348399
number_format += "}"
349400
formatted_formula += number_format
350401
number_format = ""
351-
formatted_formula += s
402+
formatted_formula += char
352403

353404
return f"${formatted_formula}$"
354405

pymatgen/util/tests/test_plotting.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@
55
from pymatgen.util.plotting import periodic_table_heatmap, van_arkel_triangle
66
from pymatgen.util.testing import PymatgenTest
77

8+
try:
9+
import pymatviz
10+
from plotly.graph_objects import Figure
11+
except ImportError:
12+
pymatviz = None
13+
814

915
class FuncTestCase(PymatgenTest):
1016
def test_plot_periodic_heatmap(self):
1117
random_data = {"Te": 0.11083, "Au": 0.75756, "Th": 1.24758, "Ni": -2.0354}
1218
ret_val = periodic_table_heatmap(random_data)
13-
assert ret_val is plt
19+
if pymatviz:
20+
assert isinstance(ret_val, Figure)
21+
else:
22+
assert ret_val is plt
1423

1524
# Test all keywords
1625
periodic_table_heatmap(
@@ -21,7 +30,7 @@ def test_plot_periodic_heatmap(self):
2130
cmap_range=[0, 1],
2231
cbar_label="Hello World",
2332
blank_color="white",
24-
value_format="%.4f",
33+
value_format=".4f",
2534
edge_color="black",
2635
value_fontsize=12,
2736
symbol_fontsize=18,

0 commit comments

Comments
 (0)