@@ -16,15 +16,24 @@ class is written.
16
16
import numpy as np
17
17
from monty .dev import deprecated
18
18
19
+ from pymatgen .core import Lattice
20
+ from pymatgen .core .units import Ha_to_eV
21
+ from pymatgen .electronic_structure .bandstructure import BandStructure
22
+ from pymatgen .electronic_structure .core import Spin
19
23
from pymatgen .io .jdftx ._output_utils import (
20
24
_get_nbands_from_bandfile_filepath ,
21
25
_get_orb_label_list ,
26
+ _get_u_to_oa_map ,
27
+ _parse_kptsfrom_bandprojections_file ,
22
28
get_proj_tju_from_file ,
29
+ orb_ref_list ,
23
30
read_outfile_slices ,
24
31
)
25
32
from pymatgen .io .jdftx .jdftxoutfileslice import JDFTXOutfileSlice
26
33
27
34
if TYPE_CHECKING :
35
+ from numpy .typing import NDArray
36
+
28
37
from pymatgen .core .structure import Structure
29
38
from pymatgen .core .trajectory import Trajectory
30
39
from pymatgen .io .jdftx .inputs import JDFTXInfile
@@ -68,6 +77,8 @@ class is written.
68
77
implemented_store_vars = (
69
78
"bandProjections" ,
70
79
"eigenvals" ,
80
+ "kpts" ,
81
+ "bandstructure" ,
71
82
)
72
83
73
84
@@ -104,17 +115,22 @@ class JDFTXOutputs:
104
115
0-based index will appear mimicking a principle quantum number (ie "0px" for first shell and "1px" for
105
116
second shell). The actual principal quantum number is not stored in the JDFTx output files and must be
106
117
inferred by the user.
118
+ kpts (list[np.ndarray]): A list of the k-points used in the calculation. Each k-point is a 3D numpy array.
119
+ wk_list (list[np.ndarray]): A list of the weights for the k-points used in the calculation.
107
120
"""
108
121
109
122
calc_dir : str | Path = field (init = True )
110
123
outfile_name : str | Path | None = field (init = True )
111
124
store_vars : list [str ] = field (default_factory = list , init = True )
112
125
paths : dict [str , Path ] = field (init = False )
113
126
outfile : JDFTXOutfile = field (init = False )
114
- bandProjections : np .ndarray | None = field (init = False )
115
- eigenvals : np .ndarray | None = field (init = False )
127
+ bandProjections : NDArray | None = field (init = False )
128
+ eigenvals : NDArray | None = field (init = False )
129
+ kpts : list [NDArray ] | None = field (init = False )
130
+ wk_list : list [NDArray ] | None = field (init = False )
116
131
# Misc metadata for interacting with the data
117
132
orb_label_list : tuple [str , ...] | None = field (init = False )
133
+ bandstructure : BandStructure | None = field (init = False )
118
134
119
135
@classmethod
120
136
def from_calc_dir (
@@ -136,13 +152,21 @@ def from_calc_dir(
136
152
Returns:
137
153
JDFTXOutputs: The JDFTXOutputs object.
138
154
"""
139
- if store_vars is None :
140
- store_vars = []
155
+ store_vars = cls ._check_store_vars (store_vars )
141
156
return cls (calc_dir = Path (calc_dir ), store_vars = store_vars , outfile_name = outfile_name )
142
157
143
158
def __post_init__ (self ):
144
159
self ._init_paths ()
145
160
self ._store_vars ()
161
+ self ._init_bandstructure ()
162
+
163
+ def _check_store_vars (store_vars : list [str ] | None ) -> list [str ]:
164
+ if store_vars is None :
165
+ return []
166
+ if "bandstructure" in store_vars :
167
+ store_vars += ["kpts" , "eigenvals" ]
168
+ store_vars .pop (store_vars .index ("bandstructure" ))
169
+ return list (set (store_vars ))
146
170
147
171
def _init_paths (self ):
148
172
self .paths = {}
@@ -208,7 +232,101 @@ def _store_eigenvals(self):
208
232
if "eigenvals" in self .paths :
209
233
self .eigenvals = np .fromfile (self .paths ["eigenvals" ])
210
234
nstates = int (len (self .eigenvals ) / self .outfile .nbands )
211
- self .eigenvals = self .eigenvals .reshape (nstates , self .outfile .nbands )
235
+ self .eigenvals = self .eigenvals .reshape (nstates , self .outfile .nbands ) * Ha_to_eV
236
+
237
+ def _check_kpts (self ):
238
+ if "bandProjections" in self .paths and self .paths ["bandProjections" ].exists ():
239
+ # TODO: Write kpt file inconsistency checking
240
+ return
241
+ if "kPts" in self .paths and self .paths ["kPts" ].exists ():
242
+ raise NotImplementedError ("kPts file parsing not yet implemented." )
243
+ raise RuntimeError ("No k-point data found in JDFTx output files." )
244
+
245
+ def _store_kpts (self ):
246
+ if "bandProjections" in self .paths and self .paths ["bandProjections" ].exists ():
247
+ wk_list , kpts_list = _parse_kptsfrom_bandprojections_file (self .paths ["bandProjections" ])
248
+ self .kpts = kpts_list
249
+ self .wk_list = wk_list
250
+
251
+ def _init_bandstructure (self ):
252
+ if True in [v is None for v in [self .kpts , self .eigenvals ]]:
253
+ return
254
+ kpoints = np .array (self .kpts )
255
+ eigenvals = self ._get_pmg_eigenvals ()
256
+ projections = None
257
+ if self .bandProjections is not None :
258
+ projections = self ._get_pmg_projections ()
259
+ self .bandstructure = BandStructure (
260
+ kpoints ,
261
+ eigenvals ,
262
+ Lattice (self .outfile .structure .lattice .reciprocal_lattice .matrix ),
263
+ self .outfile .mu ,
264
+ projections = projections ,
265
+ structure = self .outfile .structure ,
266
+ )
267
+
268
+ def _get_lmax (self ) -> tuple [str | None , int | None ]:
269
+ """Get the maximum l quantum number and projection array orbital length.
270
+
271
+ Returns:
272
+ tuple[str | None, int | None]: The maximum l quantum number and projection array orbital length.
273
+ Both are None if no bandProjections file is available.
274
+ """
275
+ if self .orb_label_list is None :
276
+ return None , None
277
+ orbs = [label .split ("(" )[1 ].split (")" )[0 ] for label in self .orb_label_list ]
278
+ if orb_ref_list [- 1 ][0 ] in orbs :
279
+ return "f" , 16
280
+ if orb_ref_list [- 2 ][0 ] in orbs :
281
+ return "d" , 9
282
+ if orb_ref_list [- 3 ][0 ] in orbs :
283
+ return "p" , 4
284
+ if orb_ref_list [- 4 ][0 ] in orbs :
285
+ return "s" , 1
286
+ raise ValueError ("Unrecognized orbital labels in orb_label_list." )
287
+
288
+ def _get_pmg_eigenvals (self ) -> dict | None :
289
+ if self .eigenvals is None :
290
+ return None
291
+ _e_skj = self .eigenvals .copy ().reshape (self .outfile .nspin , - 1 , self .outfile .nbands )
292
+ _e_sjk = np .swapaxes (_e_skj , 1 , 2 )
293
+ spins = [Spin .up , Spin .down ]
294
+ eigenvals = {}
295
+ for i in range (self .outfile .nspin ):
296
+ eigenvals [spins [i ]] = _e_sjk [i ]
297
+ return eigenvals
298
+
299
+ def _get_pmg_projections (self ) -> dict | None :
300
+ """Return pymatgen-compatible projections dictionary.
301
+
302
+ Converts the bandProjections array to a pymatgen-compatible dictionary
303
+
304
+ Returns:
305
+ dict | None:
306
+ """
307
+ lmax , norbmax = self ._get_lmax ()
308
+ if norbmax is None :
309
+ return None
310
+ if self .orb_label_list is None :
311
+ return None
312
+ _proj_tju = self .bandProjections .copy ()
313
+ if _proj_tju .dtype is np .complex64 :
314
+ # Convert <orb|band(state)> to |<orb|band(state)>|^2
315
+ _proj_tju = np .abs (_proj_tju ) ** 2
316
+ # Convert to standard datatype - using np.real to suppress warnings
317
+ proj_tju = np .array (np .real (_proj_tju ), dtype = float )
318
+ proj_skju = proj_tju .reshape ([self .outfile .nspin , - 1 , self .outfile .nbands , len (self .orb_label_list )])
319
+ proj_sjku = np .swapaxes (proj_skju , 1 , 2 )
320
+ nspin , nbands , nkpt , nproj = proj_sjku .shape
321
+ u_to_oa_map = _get_u_to_oa_map (self .paths ["bandProjections" ])
322
+ projections = {}
323
+ spins = [Spin .up , Spin .down ]
324
+ for i in range (nspin ):
325
+ projections [spins [i ]] = np .zeros ([nbands , nkpt , norbmax , len (self .outfile .structure )])
326
+ # TODO: Consider jitting this loop
327
+ for u in range (nproj ):
328
+ projections [spins [i ]][:, :, * u_to_oa_map [u ]] += proj_sjku [i , :, :, u ]
329
+ return projections
212
330
213
331
214
332
_jof_atr_from_last_slice = (
0 commit comments