44from pathlib import Path
55from typing import Optional , Union
66
7- import matplotlib .pyplot as plt
8- import numpy as np
97import phonopy
108from emmet .core .structure import StructureMetadata
11- from matplotlib import colors
12- from matplotlib .colors import LinearSegmentedColormap
139from phonopy .api_gruneisen import PhonopyGruneisen
1410from phonopy .phonon .band_structure import get_band_qpoints_and_path_connections
1511from pydantic import BaseModel , Field
2420 GruneisenParameter ,
2521 GruneisenPhononBandStructureSymmLine ,
2622)
27- from pymatgen .phonon .plotter import (
28- GruneisenPhononBSPlotter ,
29- GruneisenPlotter ,
30- freq_units ,
31- )
32- from pymatgen .util .plotting import pretty_plot
23+ from pymatgen .phonon .plotter import GruneisenPhononBSPlotter , GruneisenPlotter
3324from typing_extensions import Self
3425
3526from atomate2 .common .schemas .phonons import PhononBSDOSDoc
@@ -164,7 +155,7 @@ def from_phonon_yamls(
164155 mesh = mesh ,
165156 shift = compute_gruneisen_param_kwargs .get ("shift" ),
166157 is_gamma_center = compute_gruneisen_param_kwargs .get (
167- "is_gamma_center" , True
158+ "is_gamma_center" , False
168159 ),
169160 is_time_reversal = compute_gruneisen_param_kwargs .get (
170161 "is_time_reversal" , True
@@ -184,7 +175,7 @@ def from_phonon_yamls(
184175 mesh = kpoint .kpts [0 ],
185176 shift = compute_gruneisen_param_kwargs .get ("shift" ),
186177 is_gamma_center = compute_gruneisen_param_kwargs .get (
187- "is_gamma_center" , True
178+ "is_gamma_center" , False
188179 ),
189180 is_time_reversal = compute_gruneisen_param_kwargs .get (
190181 "is_time_reversal" , True
@@ -228,9 +219,14 @@ def from_phonon_yamls(
228219 labels_dict = kpath_dict ,
229220 )
230221 gp_bs_plot = GruneisenPhononBSPlotter (bs = gruneisen_band_structure )
231- GruneisenParameterDocument .get_gruneisen_weighted_bandstructure (
232- gruneisen_band_symline_plotter = gp_bs_plot ,
233- save_fig = True ,
222+
223+ gruneisen_bs_plot = compute_gruneisen_param_kwargs .get (
224+ "gruneisen_bs" , "gruneisen_band.pdf"
225+ )
226+ gp_bs_plot .save_plot_gs (
227+ filename = gruneisen_bs_plot ,
228+ plot_ph_bs_with_gruneisen = True ,
229+ img_format = compute_gruneisen_param_kwargs .get ("img_format" , "pdf" ),
234230 ** compute_gruneisen_param_kwargs ,
235231 )
236232 gruneisen_parameter_inputs = {
@@ -261,82 +257,3 @@ def from_phonon_yamls(
261257 gruneisen_band_structure = gruneisen_band_structure ,
262258 derived_properties = derived_properties ,
263259 )
264-
265- @staticmethod
266- def get_gruneisen_weighted_bandstructure (
267- gruneisen_band_symline_plotter : GruneisenPhononBSPlotter ,
268- save_fig : bool = True ,
269- ** kwargs ,
270- ) -> None :
271- """Save a phonon band structure weighted with Grueneisen parameters.
272-
273- Parameters
274- ----------
275- gruneisen_band_symline_plotter: GruneisenPhononBSPlotter
276- pymatgen GruneisenPhononBSPlotter obj
277- save_fig: bool
278- bool to save plots
279- kwargs: dict
280- keyword arguments to adjust plotter
281-
282- Returns
283- -------
284- None
285- """
286- u = freq_units (kwargs .get ("units" , "THz" ))
287- ax = pretty_plot (12 , 8 )
288- gruneisen_band_symline_plotter ._make_ticks (ax ) # noqa: SLF001
289-
290- # plot y=0 line
291- ax .axhline (0 , linewidth = 1 , color = "black" )
292-
293- # Create custom colormap (default is red to blue)
294- cmap = LinearSegmentedColormap .from_list (
295- "mycmap" , kwargs .get ("mycmap" , ["red" , "blue" ])
296- )
297-
298- data = gruneisen_band_symline_plotter .bs_plot_data ()
299-
300- # extract min and max Grüneisen parameter values
301- max_gruneisen = np .array (data ["gruneisen" ]).max ()
302- min_gruneisen = np .array (data ["gruneisen" ]).min ()
303-
304- # LogNormalize colormap based on the min and max Grüneisen parameter values
305- norm = colors .SymLogNorm (
306- vmin = min_gruneisen ,
307- vmax = max_gruneisen ,
308- linthresh = 1e-2 ,
309- linscale = 1 ,
310- )
311-
312- for (dists_inx , dists ), (_ , freqs ) in zip (
313- enumerate (data ["distances" ]), enumerate (data ["frequency" ]), strict = True
314- ):
315- for band_idx in range (gruneisen_band_symline_plotter .n_bands ):
316- ys = [freqs [band_idx ][j ] * u .factor for j in range (len (dists ))]
317- ys_gru = [
318- data ["gruneisen" ][dists_inx ][band_idx ][idx ]
319- for idx in range (len (data ["distances" ][dists_inx ]))
320- ]
321- sc = ax .scatter (
322- dists , ys , c = ys_gru , cmap = cmap , norm = norm , marker = "o" , s = 1
323- )
324-
325- # Main X and Y Labels
326- ax .set_xlabel (r"$\mathrm{Wave\ Vector}$" , fontsize = 30 )
327- units = kwargs .get ("units" , "THz" )
328- ax .set_ylabel (f"Frequencies ({ units } )" , fontsize = 30 )
329- # X range (K)
330- # last distance point
331- x_max = data ["distances" ][- 1 ][- 1 ]
332- ax .set_xlim (0 , x_max )
333-
334- cbar = plt .colorbar (sc , ax = ax )
335- cbar .set_label (r"$\gamma \ \mathrm{(logarithmized)}$" , fontsize = 30 )
336- plt .tight_layout ()
337- gruneisen_band_plot = kwargs .get ("gruneisen_bs" , "gruneisen_band.pdf" )
338- if save_fig :
339- plt .savefig (fname = gruneisen_band_plot )
340- plt .close ()
341- else :
342- plt .close ()
0 commit comments