Skip to content

Commit 885603c

Browse files
committed
Type fixes.
1 parent e9c4371 commit 885603c

File tree

5 files changed

+42
-50
lines changed

5 files changed

+42
-50
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ ignore_missing_imports = true
302302
namespace_packages = true
303303
no_implicit_optional = false
304304
disable_error_code = ["annotation-unchecked", "override", "operator", "attr-defined", "union-attr", "misc", "call-overload"]
305-
exclude = ['src/pymatgen/analysis', 'src/pymatgen/phonon', 'src/pymatgen/io/lobster', 'src/pymatgen/io/cp2k', 'src/pymatgen/io/lammps']
305+
exclude = ['src/pymatgen/analysis', 'src/pymatgen/phonon', 'src/pymatgen/io/cp2k', 'src/pymatgen/io/lammps']
306306
plugins = ["numpy.typing.mypy_plugin"]
307307

308308
[[tool.mypy.overrides]]

src/pymatgen/io/lobster/inputs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def write_KPOINTS(
512512
# For now, we are setting MAGMOM to zero. (Taken from INCAR class)
513513
cell = matrix, positions, zs, magmoms
514514
# TODO: what about this shift?
515-
mapping, grid = spglib.get_ir_reciprocal_mesh(mesh, cell, is_shift=[0, 0, 0])
515+
mapping, grid = spglib.get_ir_reciprocal_mesh(mesh, cell, is_shift=[0, 0, 0]) # type:ignore[arg-type]
516516

517517
# Get the KPOINTS for the grid
518518
if isym == -1:
@@ -530,7 +530,7 @@ def write_KPOINTS(
530530
weights = []
531531
all_labels = []
532532
newlist = [list(gp) for gp in list(grid)]
533-
mapping = []
533+
mapping = [] # type:ignore[assignment]
534534
for gp in newlist:
535535
minus_gp = [-k for k in gp]
536536
if minus_gp in newlist and minus_gp != [0, 0, 0]:
@@ -590,7 +590,7 @@ def from_file(cls, lobsterin: PathLike) -> Self:
590590
Lobsterin object
591591
"""
592592
with zopen(lobsterin, mode="rt", encoding="utf-8") as file:
593-
lines = file.read().split("\n")
593+
lines: list[str] = file.read().split("\n") # type:ignore[arg-type,assignment]
594594
if not lines:
595595
raise RuntimeError("lobsterin file contains no data.")
596596

@@ -640,7 +640,7 @@ def _get_potcar_symbols(POTCAR_input: PathLike) -> list[str]:
640640
Returns:
641641
list[str]: names of the species
642642
"""
643-
potcar = Potcar.from_file(POTCAR_input)
643+
potcar = Potcar.from_file(POTCAR_input) # type:ignore[arg-type]
644644
for pot in potcar:
645645
if pot.potential_type != "PAW":
646646
raise ValueError("Lobster only works with PAW! Use different POTCARs")

src/pymatgen/io/lobster/lobsterenv.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
else:
205205
bv_analyzer = BVAnalyzer()
206206
try:
207-
self.valences = bv_analyzer.get_valences(structure=self.structure)
207+
self.valences = bv_analyzer.get_valences(structure=self.structure) # type:ignore[arg-type]
208208
except ValueError as exc:
209209
self.valences = None
210210
if additional_condition in {1, 3, 5, 6}:
@@ -311,7 +311,7 @@ def get_light_structure_environment(
311311
LobsterLightStructureEnvironments
312312
"""
313313
lgf = LocalGeometryFinder()
314-
lgf.setup_structure(structure=self.structure)
314+
lgf.setup_structure(structure=self.structure) # type:ignore[arg-type]
315315
list_ce_symbols = []
316316
list_csm = []
317317
list_permut = []
@@ -342,7 +342,7 @@ def get_light_structure_environment(
342342
list_ce_symbol=list_ce_symbols,
343343
list_csm=list_csm,
344344
list_permutation=list_permut,
345-
list_neighsite=self.list_neighsite,
345+
list_neighsite=self.list_neighsite, # type:ignore[arg-type]
346346
list_neighisite=self.list_neighisite,
347347
structure=self.structure,
348348
valences=self.valences,
@@ -384,7 +384,7 @@ def get_light_structure_environment(
384384
list_ce_symbol=new_list_ce_symbols,
385385
list_csm=new_list_csm,
386386
list_permutation=new_list_permut,
387-
list_neighsite=new_list_neighsite,
387+
list_neighsite=new_list_neighsite, # type:ignore[arg-type]
388388
list_neighisite=new_list_neighisite,
389389
structure=self.structure,
390390
valences=self.valences,
@@ -436,7 +436,7 @@ def get_info_icohps_to_neighbors(
436436
if idx in isites:
437437
for key, icohpsum in zip(self.list_keys[idx], self.list_icohps[idx], strict=True):
438438
summed_icohps += icohpsum
439-
list_icohps.append(icohpsum)
439+
list_icohps.append(icohpsum) # type:ignore[arg-type]
440440
labels.append(key)
441441
atoms.append(
442442
[
@@ -617,7 +617,7 @@ def get_info_cohps_to_neighbors(
617617

618618
summed_cohp = None
619619

620-
return plot_label, summed_cohp
620+
return plot_label, summed_cohp # type:ignore[return-value]
621621

622622
def _get_plot_label(self, atoms: list[list[str]], per_bond: bool) -> str:
623623
"""Count the types of bonds and append a label."""
@@ -682,8 +682,8 @@ def get_info_icohps_between_neighbors(
682682
unitcell1 = self._determine_unit_cell(n_site)
683683
unitcell2 = self._determine_unit_cell(n_site2)
684684

685-
index_n_site = self._get_original_site(self.structure, n_site)
686-
index_n_site2 = self._get_original_site(self.structure, n_site2)
685+
index_n_site = self._get_original_site(self.structure, n_site) # type:ignore[arg-type]
686+
index_n_site2 = self._get_original_site(self.structure, n_site2) # type:ignore[arg-type]
687687

688688
if index_n_site < index_n_site2:
689689
translation = list(np.array(unitcell1) - np.array(unitcell2))
@@ -913,7 +913,7 @@ def _find_environments(
913913
Returns:
914914
Tuple of ICOHPs, keys, lengths, neighisite, neighsite, coords.
915915
"""
916-
list_icohps: list[list[IcohpValue]] = []
916+
list_icohps: list[list[float]] = []
917917
list_keys: list[list[str]] = []
918918
list_lengths: list[list[float]] = []
919919
list_neighisite: list[list[int]] = []
@@ -1013,7 +1013,7 @@ def _find_environments(
10131013
list_lengths.append([])
10141014
list_keys.append([])
10151015
list_coords.append([])
1016-
return (
1016+
return ( # type:ignore[return-value]
10171017
list_icohps,
10181018
list_keys,
10191019
list_lengths,
@@ -1027,7 +1027,7 @@ def _find_relevant_atoms_additional_condition(
10271027
site_idx: int,
10281028
icohps: dict[str, IcohpValue],
10291029
additional_condition: Literal[0, 1, 2, 3, 4, 5, 6],
1030-
) -> tuple[list[str], list[float], list[int], list[IcohpValue]]:
1030+
) -> tuple[list[str], list[float], list[int], list[float]]:
10311031
"""Find all relevant atoms that fulfill the additional condition.
10321032
10331033
Args:
@@ -1041,7 +1041,7 @@ def _find_relevant_atoms_additional_condition(
10411041
keys_from_ICOHPs: list[str] = []
10421042
lengths_from_ICOHPs: list[float] = []
10431043
neighbors_from_ICOHPs: list[int] = []
1044-
icohps_from_ICOHPs: list[IcohpValue] = []
1044+
icohps_from_ICOHPs: list[float] = []
10451045

10461046
for key, icohp in icohps.items():
10471047
atomnr1 = self._get_atomnumber(icohp._atom1)
@@ -1443,15 +1443,15 @@ def from_Lobster(
14431443

14441444
if list_neighisite[site_idx] is not None:
14451445
nb_set = cls.NeighborsSet(
1446-
structure=structure,
1446+
structure=structure, # type:ignore[arg-type]
14471447
isite=site_idx,
14481448
all_nbs_sites=all_nbs_sites,
14491449
all_nbs_sites_indices=all_nbs_sites_indices[site_idx],
14501450
)
14511451

14521452
else:
14531453
nb_set = cls.NeighborsSet(
1454-
structure=structure,
1454+
structure=structure, # type:ignore[arg-type]
14551455
isite=site_idx,
14561456
all_nbs_sites=[],
14571457
all_nbs_sites_indices=[],

src/pymatgen/io/lobster/outputs.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@
5555
)
5656

5757

58+
def _get_lines(filename) -> list[str]:
59+
with zopen(filename, mode="rt", encoding="utf-8") as file:
60+
return file.read().split("\n") # type:ignore[return-value,arg-type]
61+
62+
5863
class Cohpcar:
5964
"""Read COXXCAR.lobster/COXXCAR.LCFO.lobster files generated by LOBSTER.
6065
@@ -120,8 +125,7 @@ def __init__(
120125
else:
121126
self._filename = "COHPCAR.lobster"
122127

123-
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
124-
lines = file.read().split("\n")
128+
lines = _get_lines(filename)
125129

126130
# The parameters line is the second line in a COHPCAR file.
127131
# It contains all parameters that are needed to map the file.
@@ -405,7 +409,7 @@ def __init__(
405409
# and we don't need the header.
406410
if self._icohpcollection is None:
407411
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
408-
all_lines = file.read().splitlines()
412+
all_lines: list[str] = file.read().splitlines() # type:ignore[assignment]
409413

410414
# strip *trailing* blank lines only
411415
all_lines = [line for line in all_lines if line.strip()]
@@ -625,7 +629,7 @@ class NciCobiList:
625629
}
626630
"""
627631

628-
def __init__(self, filename: PathLike | None = "NcICOBILIST.lobster") -> None:
632+
def __init__(self, filename: PathLike = "NcICOBILIST.lobster") -> None:
629633
"""
630634
631635
LOBSTER < 4.1.0: no COBI/ICOBI/NcICOBI
@@ -636,8 +640,7 @@ def __init__(self, filename: PathLike | None = "NcICOBILIST.lobster") -> None:
636640

637641
# LOBSTER list files have an extra trailing blank line
638642
# and we don't need the header
639-
with zopen(filename, mode="rt", encoding="utf-8") as file:
640-
lines = file.read().split("\n")[1:-1]
643+
lines = _get_lines(filename)[1:-1]
641644
if len(lines) == 0:
642645
raise RuntimeError("NcICOBILIST file contains no data.")
643646

@@ -927,8 +930,7 @@ def __init__(
927930
self.loewdin = [] if loewdin is None else loewdin
928931

929932
if self.num_atoms is None:
930-
with zopen(filename, mode="rt", encoding="utf-8") as file:
931-
lines = file.read().split("\n")[3:-3]
933+
lines = _get_lines(filename)[3:-3] # type:ignore[arg-type,assignment]
932934
if len(lines) == 0:
933935
raise RuntimeError("CHARGES file contains no data.")
934936

@@ -1061,8 +1063,7 @@ def __init__(self, filename: PathLike | None, **kwargs) -> None:
10611063
else:
10621064
raise ValueError(f"{attr}={val} is not a valid attribute for Lobsterout")
10631065
elif filename:
1064-
with zopen(filename, mode="rt", encoding="utf-8") as file:
1065-
lines = file.read().split("\n")
1066+
lines = _get_lines(filename)
10661067
if len(lines) == 0:
10671068
raise RuntimeError("lobsterout does not contain any data")
10681069

@@ -1459,8 +1460,7 @@ def __init__(
14591460
raise ValueError("No FATBAND files in folder or given")
14601461

14611462
for fname in filenames:
1462-
with zopen(fname, mode="rt", encoding="utf-8") as file:
1463-
lines = file.read().split("\n")
1463+
lines = _get_lines(fname)
14641464

14651465
atom_names.append(os.path.split(fname)[1].split("_")[1].capitalize())
14661466
parameters = lines[0].split()
@@ -1493,8 +1493,7 @@ def __init__(
14931493
eigenvals: dict = {}
14941494
p_eigenvals: dict = {}
14951495
for ifilename, filename in enumerate(filenames):
1496-
with zopen(filename, mode="rt", encoding="utf-8") as file:
1497-
lines = file.read().split("\n")
1496+
lines = _get_lines(filename)
14981497

14991498
if ifilename == 0:
15001499
self.nbands = int(parameters[6])
@@ -1601,7 +1600,7 @@ def get_bandstructure(self) -> LobsterBandStructureSymmLine:
16011600
lattice=self.lattice,
16021601
efermi=self.efermi, # type: ignore[arg-type]
16031602
labels_dict=self.label_dict,
1604-
structure=self.structure,
1603+
structure=self.structure, # type:ignore[arg-type]
16051604
projections=self.p_eigenvals,
16061605
)
16071606

@@ -1641,8 +1640,7 @@ def __init__(
16411640
self.max_deviation = [] if max_deviation is None else max_deviation
16421641

16431642
if not self.band_overlaps_dict:
1644-
with zopen(filename, mode="rt", encoding="utf-8") as file:
1645-
lines = file.read().split("\n")
1643+
lines = _get_lines(filename)
16461644

16471645
spin_numbers = [0, 1] if lines[0].split()[-1] == "0" else [1, 2]
16481646

@@ -1786,8 +1784,7 @@ def __init__(
17861784
self.is_lcfo = is_lcfo
17871785
self.list_dict_grosspop = [] if list_dict_grosspop is None else list_dict_grosspop
17881786
if not self.list_dict_grosspop:
1789-
with zopen(filename, mode="rt", encoding="utf-8") as file:
1790-
lines = file.read().split("\n")
1787+
lines = _get_lines(filename)
17911788

17921789
# Read file to list of dict
17931790
small_dict: dict[str, Any] = {}
@@ -1916,8 +1913,7 @@ def _parse_file(
19161913
imaginary (list[float]): Imaginary parts of wave function.
19171914
distance (list[float]): Distances to the first point in wave function file.
19181915
"""
1919-
with zopen(filename, mode="rt", encoding="utf-8") as file:
1920-
lines = file.read().split("\n")
1916+
lines = _get_lines(filename)
19211917

19221918
points = []
19231919
distances = []
@@ -2086,8 +2082,7 @@ def __init__(
20862082
self.madelungenergies_mulliken = None if madelungenergies_mulliken is None else madelungenergies_mulliken
20872083

20882084
if self.ewald_splitting is None:
2089-
with zopen(filename, mode="rt", encoding="utf-8") as file:
2090-
lines = file.read().split("\n")[5]
2085+
lines = _get_lines(filename)[5]
20912086
if len(lines) == 0:
20922087
raise RuntimeError("MadelungEnergies file contains no data.")
20932088

@@ -2157,8 +2152,7 @@ def __init__(
21572152
self.madelungenergies_mulliken: list | float = madelungenergies_mulliken or []
21582153

21592154
if self.num_atoms is None:
2160-
with zopen(filename, mode="rt", encoding="utf-8") as file:
2161-
lines = file.read().split("\n")
2155+
lines = _get_lines(filename)
21622156
if len(lines) == 0:
21632157
raise RuntimeError("SitePotentials file contains no data.")
21642158

@@ -2311,7 +2305,7 @@ def __init__(
23112305

23122306
self._filename = str(filename)
23132307
with zopen(self._filename, mode="rt", encoding="utf-8") as file:
2314-
lines = file.readlines()
2308+
lines: list[str] = file.readlines() # type:ignore[assignment]
23152309
if len(lines) == 0:
23162310
raise RuntimeError("Please check provided input file, it seems to be empty")
23172311

@@ -2454,8 +2448,7 @@ def __init__(
24542448
self.rel_loewdin_pol_vector = {} if rel_loewdin_pol_vector is None else rel_loewdin_pol_vector
24552449

24562450
if not self.rel_loewdin_pol_vector and not self.rel_mulliken_pol_vector:
2457-
with zopen(filename, mode="rt", encoding="utf-8") as file:
2458-
lines = file.read().split("\n")
2451+
lines = _get_lines(filename)
24592452
if len(lines) == 0:
24602453
raise RuntimeError("Polarization file contains no data.")
24612454

@@ -2498,8 +2491,7 @@ def __init__(
24982491
self.bin_width = 0.0 if bin_width is None else bin_width
24992492

25002493
if not self.bwdf:
2501-
with zopen(filename, mode="rt", encoding="utf-8") as file:
2502-
lines = file.read().split("\n")
2494+
lines = _get_lines(filename)
25032495
if len(lines) == 0:
25042496
raise RuntimeError("BWDF file contains no data.")
25052497

src/pymatgen/io/vasp/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3855,7 +3855,7 @@ def write_spin(data_type: str) -> None:
38553855
if count % 5 != 0:
38563856
file.write(" " + "".join(lines) + " \n") # type:ignore[arg-type]
38573857

3858-
data = self.data_aug.get(data_type, [])
3858+
data = self.data_aug.get(data_type, []) if self.data_aug is not None else []
38593859
if isinstance(data, Iterable):
38603860
file.write("".join(data)) # type:ignore[arg-type]
38613861

0 commit comments

Comments
 (0)