@@ -16,15 +16,24 @@ class is written.
1616import numpy as np
1717from monty .dev import deprecated
1818
19+ from pymatgen .core import Lattice
20+ from pymatgen .core .units import Ha_to_eV
21+ from pymatgen .electronic_structure .bandstructure import BandStructure
22+ from pymatgen .electronic_structure .core import Spin
1923from pymatgen .io .jdftx ._output_utils import (
2024 _get_nbands_from_bandfile_filepath ,
2125 _get_orb_label_list ,
26+ _get_u_to_oa_map ,
27+ _parse_kptsfrom_bandprojections_file ,
2228 get_proj_tju_from_file ,
29+ orb_ref_list ,
2330 read_outfile_slices ,
2431)
2532from pymatgen .io .jdftx .jdftxoutfileslice import JDFTXOutfileSlice
2633
2734if TYPE_CHECKING :
35+ from numpy .typing import NDArray
36+
2837 from pymatgen .core .structure import Structure
2938 from pymatgen .core .trajectory import Trajectory
3039 from pymatgen .io .jdftx .inputs import JDFTXInfile
@@ -68,6 +77,8 @@ class is written.
6877implemented_store_vars = (
6978 "bandProjections" ,
7079 "eigenvals" ,
80+ "kpts" ,
81+ "bandstructure" ,
7182)
7283
7384
@@ -104,17 +115,22 @@ class JDFTXOutputs:
104115 0-based index will appear mimicking a principle quantum number (ie "0px" for first shell and "1px" for
105116 second shell). The actual principal quantum number is not stored in the JDFTx output files and must be
106117 inferred by the user.
118+ kpts (list[np.ndarray]): A list of the k-points used in the calculation. Each k-point is a 3D numpy array.
119+ wk_list (list[np.ndarray]): A list of the weights for the k-points used in the calculation.
107120 """
108121
109122 calc_dir : str | Path = field (init = True )
110123 outfile_name : str | Path | None = field (init = True )
111124 store_vars : list [str ] = field (default_factory = list , init = True )
112125 paths : dict [str , Path ] = field (init = False )
113126 outfile : JDFTXOutfile = field (init = False )
114- bandProjections : np .ndarray | None = field (init = False )
115- eigenvals : np .ndarray | None = field (init = False )
127+ bandProjections : NDArray | None = field (init = False )
128+ eigenvals : NDArray | None = field (init = False )
129+ kpts : list [NDArray ] | None = field (init = False )
130+ wk_list : list [NDArray ] | None = field (init = False )
116131 # Misc metadata for interacting with the data
117132 orb_label_list : tuple [str , ...] | None = field (init = False )
133+ bandstructure : BandStructure | None = field (init = False )
118134
119135 @classmethod
120136 def from_calc_dir (
@@ -136,13 +152,21 @@ def from_calc_dir(
136152 Returns:
137153 JDFTXOutputs: The JDFTXOutputs object.
138154 """
139- if store_vars is None :
140- store_vars = []
155+ store_vars = cls ._check_store_vars (store_vars )
141156 return cls (calc_dir = Path (calc_dir ), store_vars = store_vars , outfile_name = outfile_name )
142157
143158 def __post_init__ (self ):
144159 self ._init_paths ()
145160 self ._store_vars ()
161+ self ._init_bandstructure ()
162+
163+ def _check_store_vars (store_vars : list [str ] | None ) -> list [str ]:
164+ if store_vars is None :
165+ return []
166+ if "bandstructure" in store_vars :
167+ store_vars += ["kpts" , "eigenvals" ]
168+ store_vars .pop (store_vars .index ("bandstructure" ))
169+ return list (set (store_vars ))
146170
147171 def _init_paths (self ):
148172 self .paths = {}
@@ -208,7 +232,101 @@ def _store_eigenvals(self):
208232 if "eigenvals" in self .paths :
209233 self .eigenvals = np .fromfile (self .paths ["eigenvals" ])
210234 nstates = int (len (self .eigenvals ) / self .outfile .nbands )
211- self .eigenvals = self .eigenvals .reshape (nstates , self .outfile .nbands )
235+ self .eigenvals = self .eigenvals .reshape (nstates , self .outfile .nbands ) * Ha_to_eV
236+
237+ def _check_kpts (self ):
238+ if "bandProjections" in self .paths and self .paths ["bandProjections" ].exists ():
239+ # TODO: Write kpt file inconsistency checking
240+ return
241+ if "kPts" in self .paths and self .paths ["kPts" ].exists ():
242+ raise NotImplementedError ("kPts file parsing not yet implemented." )
243+ raise RuntimeError ("No k-point data found in JDFTx output files." )
244+
245+ def _store_kpts (self ):
246+ if "bandProjections" in self .paths and self .paths ["bandProjections" ].exists ():
247+ wk_list , kpts_list = _parse_kptsfrom_bandprojections_file (self .paths ["bandProjections" ])
248+ self .kpts = kpts_list
249+ self .wk_list = wk_list
250+
251+ def _init_bandstructure (self ):
252+ if True in [v is None for v in [self .kpts , self .eigenvals ]]:
253+ return
254+ kpoints = np .array (self .kpts )
255+ eigenvals = self ._get_pmg_eigenvals ()
256+ projections = None
257+ if self .bandProjections is not None :
258+ projections = self ._get_pmg_projections ()
259+ self .bandstructure = BandStructure (
260+ kpoints ,
261+ eigenvals ,
262+ Lattice (self .outfile .structure .lattice .reciprocal_lattice .matrix ),
263+ self .outfile .mu ,
264+ projections = projections ,
265+ structure = self .outfile .structure ,
266+ )
267+
268+ def _get_lmax (self ) -> tuple [str | None , int | None ]:
269+ """Get the maximum l quantum number and projection array orbital length.
270+
271+ Returns:
272+ tuple[str | None, int | None]: The maximum l quantum number and projection array orbital length.
273+ Both are None if no bandProjections file is available.
274+ """
275+ if self .orb_label_list is None :
276+ return None , None
277+ orbs = [label .split ("(" )[1 ].split (")" )[0 ] for label in self .orb_label_list ]
278+ if orb_ref_list [- 1 ][0 ] in orbs :
279+ return "f" , 16
280+ if orb_ref_list [- 2 ][0 ] in orbs :
281+ return "d" , 9
282+ if orb_ref_list [- 3 ][0 ] in orbs :
283+ return "p" , 4
284+ if orb_ref_list [- 4 ][0 ] in orbs :
285+ return "s" , 1
286+ raise ValueError ("Unrecognized orbital labels in orb_label_list." )
287+
288+ def _get_pmg_eigenvals (self ) -> dict | None :
289+ if self .eigenvals is None :
290+ return None
291+ _e_skj = self .eigenvals .copy ().reshape (self .outfile .nspin , - 1 , self .outfile .nbands )
292+ _e_sjk = np .swapaxes (_e_skj , 1 , 2 )
293+ spins = [Spin .up , Spin .down ]
294+ eigenvals = {}
295+ for i in range (self .outfile .nspin ):
296+ eigenvals [spins [i ]] = _e_sjk [i ]
297+ return eigenvals
298+
299+ def _get_pmg_projections (self ) -> dict | None :
300+ """Return pymatgen-compatible projections dictionary.
301+
302+ Converts the bandProjections array to a pymatgen-compatible dictionary
303+
304+ Returns:
305+ dict | None:
306+ """
307+ lmax , norbmax = self ._get_lmax ()
308+ if norbmax is None :
309+ return None
310+ if self .orb_label_list is None :
311+ return None
312+ _proj_tju = self .bandProjections .copy ()
313+ if _proj_tju .dtype is np .complex64 :
314+ # Convert <orb|band(state)> to |<orb|band(state)>|^2
315+ _proj_tju = np .abs (_proj_tju ) ** 2
316+ # Convert to standard datatype - using np.real to suppress warnings
317+ proj_tju = np .array (np .real (_proj_tju ), dtype = float )
318+ proj_skju = proj_tju .reshape ([self .outfile .nspin , - 1 , self .outfile .nbands , len (self .orb_label_list )])
319+ proj_sjku = np .swapaxes (proj_skju , 1 , 2 )
320+ nspin , nbands , nkpt , nproj = proj_sjku .shape
321+ u_to_oa_map = _get_u_to_oa_map (self .paths ["bandProjections" ])
322+ projections = {}
323+ spins = [Spin .up , Spin .down ]
324+ for i in range (nspin ):
325+ projections [spins [i ]] = np .zeros ([nbands , nkpt , norbmax , len (self .outfile .structure )])
326+ # TODO: Consider jitting this loop
327+ for u in range (nproj ):
328+ projections [spins [i ]][:, :, * u_to_oa_map [u ]] += proj_sjku [i , :, :, u ]
329+ return projections
212330
213331
214332_jof_atr_from_last_slice = (
0 commit comments