Skip to content

Commit 12c573a

Browse files
authored
bandstructure attribute for JDFTOutputs - initializes to BandStructure if successfully stores 'eigenvals' and 'kpts'. Initializes projections in BandStructure if successfully stores 'bandprojections'. (#4413)
1 parent e82356d commit 12c573a

File tree

3 files changed

+258
-6
lines changed

3 files changed

+258
-6
lines changed

src/pymatgen/io/jdftx/_output_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import numpy as np
1717

18+
from pymatgen.electronic_structure.core import Orbital
19+
1820
if TYPE_CHECKING:
1921
from collections.abc import Callable
2022

@@ -476,6 +478,34 @@ def get_proj_tju_from_file(bandfile_filepath: Path | str) -> NDArray[np.float32
476478
return _parse_bandfile_complex(bandfile_filepath) if is_complex else _parse_bandfile_normalized(bandfile_filepath)
477479

478480

481+
def _parse_kptsfrom_bandprojections_file(bandfile_filepath: str | Path) -> tuple[list[float], list[NDArray]]:
482+
"""Parse kpts from bandprojections file.
483+
484+
Parse kpts from bandprojections file.
485+
486+
Args:
487+
bandfile_filepath (Path | str): Path to bandprojections file.
488+
489+
Returns:
490+
tuple[list[float], list[np.ndarray[float]]]: Tuple of k-point weights and k-points
491+
"""
492+
wk_list: list[float] = []
493+
k_points_list: list[NDArray] = []
494+
kpt_lines = []
495+
with open(bandfile_filepath) as f:
496+
for line in f:
497+
if line.startswith("#") and ";" in line:
498+
_line = line.split(";")[0].lstrip("#")
499+
kpt_lines.append(_line)
500+
for line in kpt_lines:
501+
k_points = line.split("[")[1].split("]")[0].strip().split()
502+
_k_points_floats: list[float] = [float(v) for v in k_points]
503+
k_points_list.append(np.array(_k_points_floats))
504+
wk = float(line.split("]")[1].strip().split()[0])
505+
wk_list.append(wk)
506+
return wk_list, k_points_list
507+
508+
479509
def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
480510
"""Determine if bandprojections file is complex.
481511
@@ -507,6 +537,64 @@ def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
507537
["dxy", "dyz", "dz2", "dxz", "dx2-y2"],
508538
["fy(3x2-y2)", "fxyz", "fyz2", "fz3", "fxz2", "fz(x2-y2)", "fx(x2-3y2)"],
509539
]
540+
orb_ref_to_o_dict = {
541+
"s": int(Orbital.s),
542+
"py": int(Orbital.py),
543+
"pz": int(Orbital.pz),
544+
"px": int(Orbital.px),
545+
"dxy": int(Orbital.dxy),
546+
"dyz": int(Orbital.dyz),
547+
"dz2": int(Orbital.dz2),
548+
"dxz": int(Orbital.dxz),
549+
"dx2-y2": int(Orbital.dx2),
550+
# Keep the f-orbitals arbitrary-ish until they get designated names in pymatgen.
551+
orb_ref_list[-1][0]: int(Orbital.f_3),
552+
orb_ref_list[-1][1]: int(Orbital.f_2),
553+
orb_ref_list[-1][2]: int(Orbital.f_1),
554+
orb_ref_list[-1][3]: int(Orbital.f0),
555+
orb_ref_list[-1][4]: int(Orbital.f1),
556+
orb_ref_list[-1][5]: int(Orbital.f2),
557+
}
558+
559+
560+
def _get_atom_orb_labels_map_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
561+
"""
562+
Return a dictionary mapping each atom symbol to pymatgen-compatible orbital projection string representations.
563+
564+
Identical to _get_atom_orb_labels_ref_dict, but doesn't include the numbers in the labels.
565+
566+
567+
568+
Args:
569+
bandfile_filepath (str | Path): The path to the bandfile.
570+
571+
Returns:
572+
dict[str, list[str]]: A dictionary mapping each atom symbol to all atomic orbital projection string
573+
representations.
574+
"""
575+
bandfile = read_file(bandfile_filepath)
576+
labels_dict: dict[str, list[str]] = {}
577+
578+
for i, line in enumerate(bandfile):
579+
if i > 1:
580+
if "#" in line:
581+
break
582+
lsplit = line.strip().split()
583+
sym = lsplit[0]
584+
labels_dict[sym] = []
585+
lmax = int(lsplit[3])
586+
# Would prefer to use "l" rather than "L" here (as uppercase "L" means something else entirely) but
587+
# pr*-c*mm*t thinks "l" is an ambiguous variable name.
588+
for L in range(lmax + 1):
589+
mls = orb_ref_list[L]
590+
nshells = int(lsplit[4 + L])
591+
for _n in range(nshells):
592+
if nshells > 1:
593+
for ml in mls:
594+
labels_dict[sym].append(f"{ml}")
595+
else:
596+
labels_dict[sym] += mls
597+
return labels_dict
510598

511599

512600
def _get_atom_orb_labels_ref_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
@@ -623,6 +711,28 @@ def _get_orb_label(ion: str, idx: int, orb: str) -> str:
623711
return f"{ion}#{idx + 1}({orb})"
624712

625713

714+
def _get_u_to_oa_map(bandfile_filepath: Path) -> list[tuple[int, int]]:
715+
"""
716+
Return a list, where the u'th element is a tuple of the atomic orbital index and the ion index.
717+
718+
Args:
719+
bandfile_filepath (str | Path): The path to the bandfile.
720+
721+
Returns:
722+
list[tuple[int, int]]: A list, where the u'th element is a tuple of the atomic orbital index and the ion index.
723+
"""
724+
map_labels_dict = _get_atom_orb_labels_map_dict(bandfile_filepath)
725+
atom_count_list = _get_atom_count_list(bandfile_filepath)
726+
u_to_oa_map = []
727+
a = 0
728+
for ion, ion_count in atom_count_list:
729+
for _i in range(ion_count):
730+
for orb in map_labels_dict[ion]:
731+
u_to_oa_map.append((orb_ref_to_o_dict[orb], a))
732+
a += 1
733+
return u_to_oa_map
734+
735+
626736
def _get_orb_label_list(bandfile_filepath: Path) -> tuple[str, ...]:
627737
"""
628738
Return a tuple of all atomic orbital projection string representations.

src/pymatgen/io/jdftx/outputs.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,24 @@ class is written.
1616
import numpy as np
1717
from monty.dev import deprecated
1818

19+
from pymatgen.core import Lattice
20+
from pymatgen.core.units import Ha_to_eV
21+
from pymatgen.electronic_structure.bandstructure import BandStructure
22+
from pymatgen.electronic_structure.core import Spin
1923
from pymatgen.io.jdftx._output_utils import (
2024
_get_nbands_from_bandfile_filepath,
2125
_get_orb_label_list,
26+
_get_u_to_oa_map,
27+
_parse_kptsfrom_bandprojections_file,
2228
get_proj_tju_from_file,
29+
orb_ref_list,
2330
read_outfile_slices,
2431
)
2532
from pymatgen.io.jdftx.jdftxoutfileslice import JDFTXOutfileSlice
2633

2734
if TYPE_CHECKING:
35+
from numpy.typing import NDArray
36+
2837
from pymatgen.core.structure import Structure
2938
from pymatgen.core.trajectory import Trajectory
3039
from pymatgen.io.jdftx.inputs import JDFTXInfile
@@ -68,6 +77,8 @@ class is written.
6877
implemented_store_vars = (
6978
"bandProjections",
7079
"eigenvals",
80+
"kpts",
81+
"bandstructure",
7182
)
7283

7384

@@ -104,17 +115,22 @@ class JDFTXOutputs:
104115
0-based index will appear mimicking a principle quantum number (ie "0px" for first shell and "1px" for
105116
second shell). The actual principal quantum number is not stored in the JDFTx output files and must be
106117
inferred by the user.
118+
kpts (list[np.ndarray]): A list of the k-points used in the calculation. Each k-point is a 3D numpy array.
119+
wk_list (list[np.ndarray]): A list of the weights for the k-points used in the calculation.
107120
"""
108121

109122
calc_dir: str | Path = field(init=True)
110123
outfile_name: str | Path | None = field(init=True)
111124
store_vars: list[str] = field(default_factory=list, init=True)
112125
paths: dict[str, Path] = field(init=False)
113126
outfile: JDFTXOutfile = field(init=False)
114-
bandProjections: np.ndarray | None = field(init=False)
115-
eigenvals: np.ndarray | None = field(init=False)
127+
bandProjections: NDArray | None = field(init=False)
128+
eigenvals: NDArray | None = field(init=False)
129+
kpts: list[NDArray] | None = field(init=False)
130+
wk_list: list[NDArray] | None = field(init=False)
116131
# Misc metadata for interacting with the data
117132
orb_label_list: tuple[str, ...] | None = field(init=False)
133+
bandstructure: BandStructure | None = field(init=False)
118134

119135
@classmethod
120136
def from_calc_dir(
@@ -136,13 +152,21 @@ def from_calc_dir(
136152
Returns:
137153
JDFTXOutputs: The JDFTXOutputs object.
138154
"""
139-
if store_vars is None:
140-
store_vars = []
155+
store_vars = cls._check_store_vars(store_vars)
141156
return cls(calc_dir=Path(calc_dir), store_vars=store_vars, outfile_name=outfile_name)
142157

143158
def __post_init__(self):
144159
self._init_paths()
145160
self._store_vars()
161+
self._init_bandstructure()
162+
163+
def _check_store_vars(store_vars: list[str] | None) -> list[str]:
164+
if store_vars is None:
165+
return []
166+
if "bandstructure" in store_vars:
167+
store_vars += ["kpts", "eigenvals"]
168+
store_vars.pop(store_vars.index("bandstructure"))
169+
return list(set(store_vars))
146170

147171
def _init_paths(self):
148172
self.paths = {}
@@ -208,7 +232,101 @@ def _store_eigenvals(self):
208232
if "eigenvals" in self.paths:
209233
self.eigenvals = np.fromfile(self.paths["eigenvals"])
210234
nstates = int(len(self.eigenvals) / self.outfile.nbands)
211-
self.eigenvals = self.eigenvals.reshape(nstates, self.outfile.nbands)
235+
self.eigenvals = self.eigenvals.reshape(nstates, self.outfile.nbands) * Ha_to_eV
236+
237+
def _check_kpts(self):
238+
if "bandProjections" in self.paths and self.paths["bandProjections"].exists():
239+
# TODO: Write kpt file inconsistency checking
240+
return
241+
if "kPts" in self.paths and self.paths["kPts"].exists():
242+
raise NotImplementedError("kPts file parsing not yet implemented.")
243+
raise RuntimeError("No k-point data found in JDFTx output files.")
244+
245+
def _store_kpts(self):
246+
if "bandProjections" in self.paths and self.paths["bandProjections"].exists():
247+
wk_list, kpts_list = _parse_kptsfrom_bandprojections_file(self.paths["bandProjections"])
248+
self.kpts = kpts_list
249+
self.wk_list = wk_list
250+
251+
def _init_bandstructure(self):
252+
if True in [v is None for v in [self.kpts, self.eigenvals]]:
253+
return
254+
kpoints = np.array(self.kpts)
255+
eigenvals = self._get_pmg_eigenvals()
256+
projections = None
257+
if self.bandProjections is not None:
258+
projections = self._get_pmg_projections()
259+
self.bandstructure = BandStructure(
260+
kpoints,
261+
eigenvals,
262+
Lattice(self.outfile.structure.lattice.reciprocal_lattice.matrix),
263+
self.outfile.mu,
264+
projections=projections,
265+
structure=self.outfile.structure,
266+
)
267+
268+
def _get_lmax(self) -> tuple[str | None, int | None]:
269+
"""Get the maximum l quantum number and projection array orbital length.
270+
271+
Returns:
272+
tuple[str | None, int | None]: The maximum l quantum number and projection array orbital length.
273+
Both are None if no bandProjections file is available.
274+
"""
275+
if self.orb_label_list is None:
276+
return None, None
277+
orbs = [label.split("(")[1].split(")")[0] for label in self.orb_label_list]
278+
if orb_ref_list[-1][0] in orbs:
279+
return "f", 16
280+
if orb_ref_list[-2][0] in orbs:
281+
return "d", 9
282+
if orb_ref_list[-3][0] in orbs:
283+
return "p", 4
284+
if orb_ref_list[-4][0] in orbs:
285+
return "s", 1
286+
raise ValueError("Unrecognized orbital labels in orb_label_list.")
287+
288+
def _get_pmg_eigenvals(self) -> dict | None:
289+
if self.eigenvals is None:
290+
return None
291+
_e_skj = self.eigenvals.copy().reshape(self.outfile.nspin, -1, self.outfile.nbands)
292+
_e_sjk = np.swapaxes(_e_skj, 1, 2)
293+
spins = [Spin.up, Spin.down]
294+
eigenvals = {}
295+
for i in range(self.outfile.nspin):
296+
eigenvals[spins[i]] = _e_sjk[i]
297+
return eigenvals
298+
299+
def _get_pmg_projections(self) -> dict | None:
300+
"""Return pymatgen-compatible projections dictionary.
301+
302+
Converts the bandProjections array to a pymatgen-compatible dictionary
303+
304+
Returns:
305+
dict | None:
306+
"""
307+
lmax, norbmax = self._get_lmax()
308+
if norbmax is None:
309+
return None
310+
if self.orb_label_list is None:
311+
return None
312+
_proj_tju = self.bandProjections.copy()
313+
if _proj_tju.dtype is np.complex64:
314+
# Convert <orb|band(state)> to |<orb|band(state)>|^2
315+
_proj_tju = np.abs(_proj_tju) ** 2
316+
# Convert to standard datatype - using np.real to suppress warnings
317+
proj_tju = np.array(np.real(_proj_tju), dtype=float)
318+
proj_skju = proj_tju.reshape([self.outfile.nspin, -1, self.outfile.nbands, len(self.orb_label_list)])
319+
proj_sjku = np.swapaxes(proj_skju, 1, 2)
320+
nspin, nbands, nkpt, nproj = proj_sjku.shape
321+
u_to_oa_map = _get_u_to_oa_map(self.paths["bandProjections"])
322+
projections = {}
323+
spins = [Spin.up, Spin.down]
324+
for i in range(nspin):
325+
projections[spins[i]] = np.zeros([nbands, nkpt, norbmax, len(self.outfile.structure)])
326+
# TODO: Consider jitting this loop
327+
for u in range(nproj):
328+
projections[spins[i]][:, :, *u_to_oa_map[u]] += proj_sjku[i, :, :, u]
329+
return projections
212330

213331

214332
_jof_atr_from_last_slice = (

tests/io/jdftx/test_jdftxoutput.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
if TYPE_CHECKING:
2020
from pathlib import Path
2121

22+
_implemented_store_vars = [v for v in implemented_store_vars if v != "bandstructure"]
23+
2224

2325
@pytest.mark.parametrize(
2426
("calc_dir", "known_paths"),
@@ -34,6 +36,7 @@ def test_known_paths(calc_dir: Path, known_paths: dict):
3436
@pytest.mark.parametrize(
3537
("calc_dir", "store_vars"),
3638
[
39+
(n2_ex_calc_dir, ["bandProjections", "eigenvals", "kpts"]),
3740
(n2_ex_calc_dir, ["bandProjections", "eigenvals"]),
3841
(n2_ex_calc_dir, ["bandProjections"]),
3942
(n2_ex_calc_dir, []),
@@ -42,14 +45,35 @@ def test_known_paths(calc_dir: Path, known_paths: dict):
4245
def test_store_vars(calc_dir: Path, store_vars: list[str]):
4346
"""Test that the stored variables are correct."""
4447
jo = JDFTXOutputs.from_calc_dir(calc_dir, store_vars=store_vars)
45-
for var in implemented_store_vars:
48+
for var in _implemented_store_vars:
4649
assert hasattr(jo, var)
4750
if var in store_vars:
4851
assert getattr(jo, var) is not None
4952
else:
5053
assert getattr(jo, var) is None
5154

5255

56+
@pytest.mark.parametrize(
57+
("calc_dir", "store_vars"),
58+
[
59+
(n2_ex_calc_dir, ["bandProjections", "eigenvals", "kpts", "bandstructure"]),
60+
(n2_ex_calc_dir, ["bandProjections", "bandstructure"]),
61+
(n2_ex_calc_dir, ["bandstructure"]),
62+
],
63+
)
64+
def test_bandstructure_construction(calc_dir: Path, store_vars: list[str]):
65+
"""Test that the stored variables are correct."""
66+
jo = JDFTXOutputs.from_calc_dir(calc_dir, store_vars=store_vars)
67+
assert hasattr(jo, "bandstructure")
68+
required_for_bandstructure = ["eigenvals", "kpts"]
69+
for var in required_for_bandstructure:
70+
assert hasattr(jo, var)
71+
if "bandProjections" in store_vars:
72+
assert len(jo.bandstructure.projections)
73+
else:
74+
assert not len(jo.bandstructure.projections)
75+
76+
5377
@pytest.mark.parametrize(
5478
("calc_dir", "known_metadata"),
5579
[

0 commit comments

Comments
 (0)