Skip to content

Commit fc5286e

Browse files
Remove LOBSTER output file trailing line sensitivity (#4448)
* clean up tests
* remove trailing line sensitivity
* simplify temp file
* lint fix
* add test
* fix filename

Co-authored-by: Aakash Ashok Naik <[email protected]>
Signed-off-by: Haoyu (Daniel) YANG 杨浩宇 <[email protected]>

* also test default filename
* test `_get_lines`

---------

Signed-off-by: Haoyu (Daniel) YANG 杨浩宇 <[email protected]>
Co-authored-by: Aakash Ashok Naik <[email protected]>
1 parent 2f2c8d7 commit fc5286e

File tree

2 files changed

+67
-59
lines changed

2 files changed

+67
-59
lines changed

src/pymatgen/io/lobster/outputs.py

Lines changed: 25 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@
5757

5858
def _get_lines(filename) -> list[str]:
5959
with zopen(filename, mode="rt", encoding="utf-8") as file:
60-
return file.read().split("\n") # type:ignore[return-value,arg-type]
60+
return cast("list[str]", file.read().splitlines())
6161

6262

6363
class Cohpcar:
@@ -109,7 +109,7 @@ def __init__(
109109
or (are_coops and are_multi_center_cobis)
110110
or (are_cobis and are_multi_center_cobis)
111111
):
112-
raise ValueError("You cannot have info about COOPs, COBIs and/or multi-center COBIS in the same file.")
112+
raise ValueError("You cannot have info about COOPs, COBIs and/or multi-center COBIs in the same file.")
113113

114114
self.are_coops = are_coops
115115
self.are_cobis = are_cobis
@@ -125,7 +125,7 @@ def __init__(
125125
else:
126126
self._filename = "COHPCAR.lobster"
127127

128-
lines = _get_lines(filename)
128+
lines: list[str] = _get_lines(self._filename)
129129

130130
# The parameters line is the second line in a COHPCAR file.
131131
# It contains all parameters that are needed to map the file.
@@ -136,24 +136,23 @@ def __init__(
136136
self.is_spin_polarized = int(parameters[1]) == 2
137137
spins = [Spin.up, Spin.down] if int(parameters[1]) == 2 else [Spin.up]
138138
cohp_data: dict[str, dict[str, Any]] = {}
139+
140+
# The COHP/COBI data start from line num_bonds + 3
141+
data = np.array([np.array(line.split(), dtype=float) for line in lines[num_bonds + 3 :]]).transpose()
142+
139143
if not self.are_multi_center_cobis:
140-
# The COHP data start in line num_bonds + 3
141-
data = np.array([np.array(line.split(), dtype=float) for line in lines[num_bonds + 3 :]]).transpose()
142144
cohp_data = {
143145
"average": {
144146
"COHP": {spin: data[1 + 2 * s * (num_bonds + 1)] for s, spin in enumerate(spins)},
145147
"ICOHP": {spin: data[2 + 2 * s * (num_bonds + 1)] for s, spin in enumerate(spins)},
146148
}
147149
}
148-
else:
149-
# The COBI data start in line num_bonds + 3 if multi-center cobis exist
150-
data = np.array([np.array(line.split(), dtype=float) for line in lines[num_bonds + 3 :]]).transpose()
151150

152151
self.energies = data[0]
153152

154153
orb_cohp: dict[str, Any] = {}
155154
# Present for LOBSTER versions older than 2.2.0
156-
very_old = False
155+
older_than_2_2_0: bool = False
157156

158157
# The label has to be changed: there is more than one COHP for each atom combination
159158
# this is done to make the labeling consistent with ICOHPLIST.lobster
@@ -192,8 +191,8 @@ def __init__(
192191
else:
193192
# Present for LOBSTER versions older than 2.2.0
194193
if bond_num == 0:
195-
very_old = True
196-
if very_old:
194+
older_than_2_2_0 = True
195+
if older_than_2_2_0:
197196
bond_num += 1
198197
label = str(bond_num)
199198

@@ -245,8 +244,8 @@ def __init__(
245244
else:
246245
# Present for LOBSTER versions older than 2.2.0
247246
if bond_num == 0:
248-
very_old = True
249-
if very_old:
247+
older_than_2_2_0 = True
248+
if older_than_2_2_0:
250249
bond_num += 1
251250
label = str(bond_num)
252251

@@ -261,7 +260,7 @@ def __init__(
261260
}
262261

263262
# Present for LOBSTER older than 2.2.0
264-
if very_old:
263+
if older_than_2_2_0:
265264
for bond_str in orb_cohp:
266265
cohp_data[bond_str] = {
267266
"COHP": None,
@@ -405,14 +404,10 @@ def __init__(
405404
else:
406405
self._filename = "ICOHPLIST.lobster"
407406

408-
# LOBSTER list files have an extra trailing blank line
409-
# and we don't need the header.
410407
if self._icohpcollection is None:
411408
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
412-
all_lines: list[str] = file.read().splitlines() # type:ignore[assignment]
409+
all_lines: list[str] = cast("list[str]", file.read().splitlines())
413410

414-
# strip *trailing* blank lines only
415-
all_lines = [line for line in all_lines if line.strip()]
416411
# --- detect header length robustly ---
417412
header_len = 0
418413
try:
@@ -442,7 +437,7 @@ def __init__(
442437
# If the calculation is spin polarized, the line in the middle
443438
# of the file will be another header line.
444439
# TODO: adapt this for orbital-wise stuff
445-
if version in ("3.1.1", "2.2.1"):
440+
if version in {"3.1.1", "2.2.1"}:
446441
self.is_spin_polarized = "distance" in lines[len(lines) // 2]
447442
else: # if version == "5.1.0":
448443
self.is_spin_polarized = len(lines[0].split()) == 9
@@ -637,10 +632,8 @@ def __init__(self, filename: PathLike = "NcICOBILIST.lobster") -> None:
637632
Args:
638633
filename: Name of the NcICOBILIST file.
639634
"""
640-
641-
# LOBSTER list files have an extra trailing blank line
642-
# and we don't need the header
643-
lines = _get_lines(filename)[1:-1]
635+
# We don't need the header
636+
lines = _get_lines(filename)[1:]
644637
if len(lines) == 0:
645638
raise RuntimeError("NcICOBILIST file contains no data.")
646639

@@ -930,7 +923,7 @@ def __init__(
930923
self.loewdin = [] if loewdin is None else loewdin
931924

932925
if self.num_atoms is None:
933-
lines = _get_lines(filename)[3:-3] # type:ignore[arg-type,assignment]
926+
lines = _get_lines(filename)[3:-2]
934927
if len(lines) == 0:
935928
raise RuntimeError("CHARGES file contains no data.")
936929

@@ -1105,10 +1098,12 @@ def __init__(self, filename: PathLike | None, **kwargs) -> None:
11051098
self.has_doscar_lso = (
11061099
"writing DOSCAR.LSO.lobster..." in lines and "SKIPPING writing DOSCAR.LSO.lobster..." not in lines
11071100
)
1101+
11081102
try:
11091103
version_number = float(".".join(self.lobster_version.strip("v").split(".")[:2]))
11101104
except ValueError:
11111105
version_number = 0.0
1106+
11121107
if version_number < 5.1:
11131108
self.has_cohpcar = (
11141109
"writing COOPCAR.lobster and ICOOPLIST.lobster..." in lines
@@ -1452,9 +1447,7 @@ def __init__(
14521447
for name in os.listdir(filenames):
14531448
if fnmatch.fnmatch(name, "FATBAND_*.lobster"):
14541449
filenames_new.append(os.path.join(filenames, name))
1455-
filenames = filenames_new # type: ignore[assignment]
1456-
1457-
filenames = cast("list[PathLike]", filenames)
1450+
filenames = cast("list[PathLike]", filenames_new)
14581451

14591452
if len(filenames) == 0:
14601453
raise ValueError("No FATBAND files in folder or given")
@@ -1546,7 +1539,7 @@ def __init__(
15461539

15471540
idx_kpt = -1
15481541
linenumber = iband = 0
1549-
for line in lines[1:-1]:
1542+
for line in lines[1:]:
15501543
if line.split()[0] == "#":
15511544
KPOINT = np.array(
15521545
[
@@ -1600,7 +1593,7 @@ def get_bandstructure(self) -> LobsterBandStructureSymmLine:
16001593
lattice=self.lattice,
16011594
efermi=self.efermi, # type: ignore[arg-type]
16021595
labels_dict=self.label_dict,
1603-
structure=self.structure, # type:ignore[arg-type]
1596+
structure=self.structure, # type: ignore[arg-type]
16041597
projections=self.p_eigenvals,
16051598
)
16061599

@@ -2159,7 +2152,7 @@ def __init__(
21592152
self._filename = filename
21602153
self.ewald_splitting = float(lines[0].split()[9])
21612154

2162-
lines = lines[5:-1]
2155+
lines = lines[5:]
21632156
self.num_atoms = len(lines) - 2
21642157
for atom in range(self.num_atoms):
21652158
line_parts = lines[atom].split()
@@ -2305,7 +2298,7 @@ def __init__(
23052298

23062299
self._filename = str(filename)
23072300
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
2308-
lines: list[str] = file.readlines() # type:ignore[assignment]
2301+
lines: list[str] = cast("list[str]", file.readlines())
23092302
if len(lines) == 0:
23102303
raise RuntimeError("Please check provided input file, it seems to be empty")
23112304

tests/io/lobster/test_outputs.py

Lines changed: 42 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import copy
4+
import gzip
45
import os
5-
import tempfile
66

77
import numpy as np
88
import orjson
@@ -30,6 +30,7 @@
3030
SitePotential,
3131
Wavefunction,
3232
)
33+
from pymatgen.io.lobster.outputs import _get_lines
3334
from pymatgen.io.vasp import Vasprun
3435
from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, MatSciTest
3536

@@ -64,7 +65,17 @@ def setup_method(self):
6465
filename=f"{TEST_DIR}/COOPCAR.lobster.BiSe.gz",
6566
are_coops=True,
6667
)
67-
self.cohp_fe = Cohpcar(filename=f"{TEST_DIR}/COOPCAR.lobster.gz")
68+
69+
# Make sure Cohpcar also works when the file ends with a trailing newline
70+
gz_path = f"{TEST_DIR}/COOPCAR.lobster.gz"
71+
with gzip.open(gz_path, "rt", encoding="utf-8") as f:
72+
content = f.read() + "\n"
73+
74+
# Test default filename (None should be redirected to "COHPCAR.lobster")
75+
with open("COHPCAR.lobster", "w", encoding="utf-8") as f:
76+
f.write(content)
77+
78+
self.cohp_fe = Cohpcar(filename=None)
6879
self.coop_fe = Cohpcar(
6980
filename=f"{TEST_DIR}/COOPCAR.lobster.gz",
7081
are_coops=True,
@@ -645,16 +656,11 @@ def setup_method(self):
645656
self.charge_lcfo = Charge(filename=f"{TEST_DIR}/CHARGE.LCFO.lobster.ALN.gz", is_lcfo=True)
646657

647658
def test_attributes(self):
648-
charge_Loewdin = [-1.25, 1.25]
649-
charge_Mulliken = [-1.30, 1.30]
650-
atomlist = ["O1", "Mn2"]
651-
types = ["O", "Mn"]
652-
num_atoms = 2
653-
assert charge_Mulliken == self.charge2.mulliken
654-
assert charge_Loewdin == self.charge2.loewdin
655-
assert atomlist == self.charge2.atomlist
656-
assert types == self.charge2.types
657-
assert num_atoms == self.charge2.num_atoms
659+
assert self.charge2.mulliken == approx([-1.30, 1.30])
660+
assert self.charge2.loewdin == approx([-1.25, 1.25])
661+
assert self.charge2.atomlist == ["O1", "Mn2"]
662+
assert self.charge2.types == ["O", "Mn"]
663+
assert self.charge2.num_atoms == 2
658664

659665
# test with CHARGE.LCFO.lobster file
660666
assert self.charge_lcfo.is_lcfo
@@ -1866,7 +1872,7 @@ def test_msonable(self):
18661872
assert getattr(grosspop_from_dict, attr_name) == attr_value
18671873

18681874

1869-
class TestIcohplist:
1875+
class TestIcohplist(MatSciTest):
18701876
def setup_method(self):
18711877
self.icohp_bise = Icohplist(filename=f"{TEST_DIR}/ICOHPLIST.lobster.BiSe")
18721878
self.icoop_bise = Icohplist(
@@ -2173,21 +2179,16 @@ def test_msonable(self):
21732179
assert getattr(icohplist_from_dict, attr_name) == attr_value
21742180

21752181
def test_missing_trailing_newline(self):
2176-
content = (
2177-
"1 Co1 O1 1.00000 0 0 0 -0.50000 -1.00000\n"
2178-
"2 Co2 O2 1.10000 0 0 0 -0.60000 -1.10000"
2179-
)
2182+
fname = f"{self.tmp_path}/icohplist"
2183+
with open(fname, mode="w", encoding="utf-8") as f:
2184+
f.write(
2185+
"1 Co1 O1 1.00000 0 0 0 -0.50000 -1.00000\n"
2186+
"2 Co2 O2 1.10000 0 0 0 -0.60000 -1.10000"
2187+
)
21802188

2181-
with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
2182-
tmp.write(content)
2183-
tmp.flush()
2184-
fname = tmp.name
2185-
try:
2186-
ip = Icohplist(filename=fname)
2187-
assert len(ip.icohplist) == 2
2188-
assert ip.icohplist["1"]["icohp"][Spin.up] == approx(-0.5)
2189-
finally:
2190-
os.remove(fname)
2189+
ip = Icohplist(filename=fname)
2190+
assert len(ip.icohplist) == 2
2191+
assert ip.icohplist["1"]["icohp"][Spin.up] == approx(-0.5)
21912192

21922193

21932194
class TestNciCobiList:
@@ -2531,3 +2532,17 @@ def test_attributes(self):
25312532
"abs": 56.14,
25322533
"unit": "uC/cm2",
25332534
}
2535+
2536+
2537+
def test_get_lines():
2538+
"""Ensure `_get_lines` is not trailing end char sensitive."""
2539+
with open("without-end-char", mode="wb") as f:
2540+
f.write(b"first line\nsecond line")
2541+
2542+
with open("with-end-char", mode="wb") as f:
2543+
f.write(b"first line\nsecond line\n")
2544+
2545+
without_end_char = _get_lines("without-end-char")
2546+
with_end_char = _get_lines("with-end-char")
2547+
2548+
assert len(with_end_char) == len(without_end_char) == 2

0 commit comments

Comments
 (0)