Skip to content

Commit b7b8e81

Browse files
mypy / precommit
1 parent 02c0c4e commit b7b8e81

27 files changed

+102
-79
lines changed

.github/workflows/linting.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
python -m pip install types-requests
2626
- name: mypy
2727
run: |
28-
mypy --namespace-packages --explicit-package-bases pymatgen
28+
mypy pymatgen
2929
- name: black
3030
run: |
3131
black --version

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ repos:
5252
args:
5353
- --namespace-packages
5454
- --explicit-package-bases
55-
additional_dependencies: ['types-requests']
55+
additional_dependencies: ['types-requests','pydantic>=2.0.1']

pymatgen/io/validation/check_common_errors.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pymatgen.io.validation.common import SETTINGS, BaseValidator
99

1010
if TYPE_CHECKING:
11+
from collections.abc import Sequence
1112
from numpy.typing import ArrayLike
1213

1314
from pymatgen.io.validation.common import VaspFiles
@@ -151,8 +152,12 @@ def _check_positive_energy(self, vasp_files: VaspFiles, reasons: list[str], warn
151152
def _check_large_magmoms(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> None:
152153
# Check for excessively large final magnetic moments
153154

154-
if not vasp_files.outcar:
155-
warnings.append("MAGNETISM --> No OUTCAR file specified")
155+
if (
156+
not vasp_files.outcar
157+
or not vasp_files.outcar.magnetization
158+
or any(mag.get("tot") is None for mag in vasp_files.outcar.magnetization)
159+
):
160+
warnings.append("MAGNETISM --> No OUTCAR file specified or data missing.")
156161
return
157162

158163
cur_magmoms = [abs(mag["tot"]) for mag in vasp_files.outcar.magnetization]
@@ -221,7 +226,7 @@ class CheckStructureProperties(BaseValidator):
221226
)
222227

223228
@staticmethod
224-
def _has_frozen_degrees_of_freedom(selective_dynamics_array: ArrayLike[bool] | None) -> bool:
229+
def _has_frozen_degrees_of_freedom(selective_dynamics_array: Sequence[bool] | None) -> bool:
225230
"""Check selective dynamics array for False values."""
226231
if selective_dynamics_array is None:
227232
return False
@@ -243,7 +248,7 @@ def _check_selective_dynamics(self, vasp_files: VaspFiles, reasons: list[str], w
243248
def _has_nonzero_velocities(velocities: ArrayLike | None, tol: float = 1.0e-8) -> bool:
244249
if velocities is None:
245250
return False
246-
return np.any(np.abs(velocities) > tol)
251+
return np.any(np.abs(velocities) > tol) # type: ignore [return-value]
247252

248253
def _check_velocities(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> None:
249254
"""Check structure for non-zero velocities."""

pymatgen/io/validation/check_incar.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,7 @@ def _update_misc_special_params(self, user_incar: dict, ref_incar: dict, vasp_fi
236236

237237
if (
238238
user_incar["ISPIN"] == 2
239-
and vasp_files.outcar
240-
and len(vasp_files.outcar.magnetization) != vasp_files.user_input.structure.num_sites
239+
and len(getattr(vasp_files.outcar, "magnetization", [])) != vasp_files.user_input.structure.num_sites
241240
):
242241
self.vasp_defaults["LORBIT"].update(
243242
{
@@ -308,18 +307,19 @@ def _update_fft_params(self, user_incar: dict, ref_incar: dict, vasp_files: Vasp
308307
enmaxs = [user_incar["ENMAX"], ref_incar["ENMAX"]]
309308
ref_incar["ENMAX"] = max([v for v in enmaxs if v < float("inf")])
310309

311-
(
312-
[
313-
ref_incar["NGX"],
314-
ref_incar["NGY"],
315-
ref_incar["NGZ"],
316-
],
317-
[
318-
ref_incar["NGXF"],
319-
ref_incar["NGYF"],
320-
ref_incar["NGZF"],
321-
],
322-
) = vasp_files.valid_input_set._calculate_ng(custom_encut=ref_incar["ENMAX"])
310+
if fft_grid := vasp_files.valid_input_set._calculate_ng(custom_encut=ref_incar["ENMAX"]):
311+
(
312+
[
313+
ref_incar["NGX"],
314+
ref_incar["NGY"],
315+
ref_incar["NGZ"],
316+
],
317+
[
318+
ref_incar["NGXF"],
319+
ref_incar["NGYF"],
320+
ref_incar["NGZF"],
321+
],
322+
) = fft_grid
323323

324324
for key in grid_keys:
325325
ref_incar[key] = int(ref_incar[key] * self.fft_grid_tolerance)
@@ -490,7 +490,7 @@ def _update_electronic_params(self, user_incar: dict, ref_incar: dict, vasp_file
490490

491491
# ENAUG. Should only be checked for calculations where the relevant MP input set specifies ENAUG.
492492
# In that case, ENAUG should be the same or greater than in valid_input_set.
493-
if ref_incar.get("ENAUG") < float("inf"):
493+
if ref_incar.get("ENAUG") and not np.isinf(ref_incar["ENAUG"]):
494494
self.vasp_defaults["ENAUG"].operation = ">="
495495

496496
# IALGO.
@@ -502,7 +502,7 @@ def _update_electronic_params(self, user_incar: dict, ref_incar: dict, vasp_file
502502
if vasp_files.vasprun and (nelect := vasp_files.vasprun.parameters.get("NELECT")):
503503
ref_incar["NELECT"] = 0.0
504504
try:
505-
user_incar["NELECT"] = float(vasp_files.vasprun.final_structure._charge)
505+
user_incar["NELECT"] = float(vasp_files.vasprun.final_structure._charge or 0.0)
506506
self.vasp_defaults["NELECT"].operation = "approx"
507507
self.vasp_defaults["NELECT"].comment = (
508508
f"This causes the structure to have a charge of {user_incar['NELECT']}. "

pymatgen/io/validation/check_kpoints_kspacing.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class CheckKpointsKspacing(BaseValidator):
1515
"""Check that k-point density is sufficiently high and is compatible with lattice symmetry."""
1616

1717
name: str = "Check k-point density"
18-
kpts_tolerance: float | None = Field(
18+
kpts_tolerance: float = Field(
1919
SETTINGS.VASP_KPTS_TOLERANCE,
2020
description="Tolerance for evaluating k-point density, to accommodate different the k-point generation schemes across VASP versions.",
2121
)
@@ -46,16 +46,18 @@ def _get_valid_num_kpts(
4646
int, the minimum permitted number of k-points, consistent with self.kpts_tolerance
4747
"""
4848
# If MP input set specifies KSPACING in the INCAR
49-
if ("KSPACING" in vasp_files.valid_input_set.incar.keys()) and (vasp_files.valid_input_set.kpoints is None):
50-
valid_kspacing = vasp_files.valid_input_set.incar.get("KSPACING", self.vasp_defaults["KSPACING"].value)
49+
if (kspacing := vasp_files.valid_input_set.incar.get("KSPACING")) and (
50+
vasp_files.valid_input_set.kpoints is None
51+
):
52+
valid_kspacing = kspacing
5153
# number of kpoints along each of the three lattice vectors
5254
nk = [
5355
max(1, np.ceil(vasp_files.user_input.structure.lattice.reciprocal_lattice.abc[ik] / valid_kspacing))
5456
for ik in range(3)
5557
]
5658
valid_num_kpts = np.prod(nk)
5759
# If MP input set specifies a KPOINTS file
58-
else:
60+
elif vasp_files.valid_input_set.kpoints:
5961
valid_num_kpts = vasp_files.valid_input_set.kpoints.num_kpts or np.prod(
6062
vasp_files.valid_input_set.kpoints.kpts[0]
6163
)
@@ -64,7 +66,7 @@ def _get_valid_num_kpts(
6466

6567
def _check_user_shifted_mesh(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> None:
6668
# Check for user shifts
67-
if (not self.allow_kpoint_shifts) and any(shift_val != 0 for shift_val in vasp_files.actual_kpoints.kpts_shift):
69+
if (not self.allow_kpoint_shifts) and any(shift_val != 0 for shift_val in vasp_files.actual_kpoints.kpts_shift): # type: ignore[union-attr]
6870
reasons.append("INPUT SETTINGS --> KPOINTS: shifting the kpoint mesh is not currently allowed.")
6971

7072
def _check_explicit_mesh_permitted(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> None:
@@ -77,7 +79,7 @@ def _check_explicit_mesh_permitted(self, vasp_files: VaspFiles, reasons: list[st
7779
else:
7880
allow_explicit = False
7981

80-
if (not allow_explicit) and len(vasp_files.actual_kpoints.kpts) > 1:
82+
if (not allow_explicit) and len(vasp_files.actual_kpoints.kpts) > 1: # type: ignore[union-attr]
8183
reasons.append(
8284
"INPUT SETTINGS --> KPOINTS: explicitly defining "
8385
"the k-point mesh is not currently allowed. "
@@ -92,10 +94,10 @@ def _check_kpoint_density(self, vasp_files: VaspFiles, reasons: list[str], warni
9294
# Check number of kpoints used
9395
valid_num_kpts = self._get_valid_num_kpts(vasp_files)
9496

95-
cur_num_kpts = max(
96-
vasp_files.actual_kpoints.num_kpts,
97-
np.prod(vasp_files.actual_kpoints.kpts),
98-
len(vasp_files.actual_kpoints.kpts),
97+
cur_num_kpts: int = max( # type: ignore[assignment]
98+
vasp_files.actual_kpoints.num_kpts, # type: ignore[union-attr]
99+
np.prod(vasp_files.actual_kpoints.kpts), # type: ignore[union-attr]
100+
len(vasp_files.actual_kpoints.kpts), # type: ignore[union-attr]
99101
)
100102
if cur_num_kpts < valid_num_kpts:
101103
reasons.append(
@@ -106,19 +108,19 @@ def _check_kpoint_density(self, vasp_files: VaspFiles, reasons: list[str], warni
106108
def _check_kpoint_mesh_symmetry(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> None:
107109
# check for valid kpoint mesh (which depends on symmetry of the structure)
108110

109-
cur_kpoint_style = vasp_files.actual_kpoints.style.name.lower()
111+
cur_kpoint_style = vasp_files.actual_kpoints.style.name.lower() # type: ignore[union-attr]
110112
is_hexagonal = vasp_files.user_input.structure.lattice.is_hexagonal()
111113
is_face_centered = vasp_files.user_input.structure.get_space_group_info()[0][0] == "F"
112114
monkhorst_mesh_is_invalid = is_hexagonal or is_face_centered
113115
if (
114116
cur_kpoint_style == "monkhorst"
115117
and monkhorst_mesh_is_invalid
116-
and any(x % 2 == 0 for x in vasp_files.actual_kpoints.kpts[0])
118+
and any(x % 2 == 0 for x in vasp_files.actual_kpoints.kpts[0]) # type: ignore[union-attr]
117119
):
118120
# only allow Monkhorst with all odd number of subdivisions per axis.
119-
kx, ky, kz = vasp_files.actual_kpoints.kpts[0]
121+
kv = vasp_files.actual_kpoints.kpts[0] # type: ignore[union-attr]
120122
reasons.append(
121-
f"INPUT SETTINGS --> KPOINTS or KGAMMA: ({kx}x{ky}x{kz}) "
123+
f"INPUT SETTINGS --> KPOINTS or KGAMMA: ({'×'.join([f'{_k}' for _k in kv])}) "
122124
"Monkhorst-Pack kpoint mesh was used."
123125
"To be compatible with the symmetry of the lattice, "
124126
"a Monkhorst-Pack mesh should have only odd number of "

pymatgen/io/validation/check_package_versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44
from importlib.metadata import version
5-
import requests
5+
import requests # type: ignore[import-untyped]
66
import warnings
77

88

pymatgen/io/validation/check_potcar.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
from pymatgen.io.vasp import PotcarSingle
1313

14-
from pymatgen.io.validation.common import BaseValidator
14+
from pymatgen.io.validation.common import BaseValidator, ValidationError
1515

1616
if TYPE_CHECKING:
17+
from typing import Any
1718
from pymatgen.io.validation.common import VaspFiles
1819

1920

@@ -24,7 +25,7 @@ class CheckPotcar(BaseValidator):
2425

2526
name: str = "Check POTCAR"
2627
potcar_summary_stats_path: str | Path | None = Field(
27-
import_resource_files("pymatgen.io.vasp") / "potcar-summary-stats.json.bz2",
28+
str(import_resource_files("pymatgen.io.vasp") / "potcar-summary-stats.json.bz2"),
2829
description="Path to potcar summary data. Mapping is calculation type -> potcar symbol -> summary data.",
2930
)
3031
data_match_tol: float = Field(1.0e-6, description="Tolerance for matching POTCARs to summary statistics data.")
@@ -33,11 +34,11 @@ class CheckPotcar(BaseValidator):
3334
)
3435

3536
@cached_property
36-
def potcar_summary_stats(self) -> dict | None:
37+
def potcar_summary_stats(self) -> dict:
3738
"""Load POTCAR summary statistics file."""
3839
if self.potcar_summary_stats_path:
3940
return loadfn(self.potcar_summary_stats_path, cls=None)
40-
return None
41+
return {}
4142

4243
def auto_fail(self, vasp_files: VaspFiles, reasons: list[str], warnings: list[str]) -> bool:
4344
"""Skip if no POTCAR was provided, or if summary stats file was unset."""
@@ -63,26 +64,28 @@ def _check_potcar_spec(self, vasp_files: VaspFiles, reasons: list[str], warnings
6364

6465
if vasp_files.valid_input_set.potcar:
6566
# If the user has pymatgen set up, use the pregenerated POTCAR summary stats.
66-
valid_potcar_summary_stats = {
67+
valid_potcar_summary_stats: dict[str, list[dict[str, Any]]] = {
6768
p.titel.replace(" ", ""): [p.model_dump()] for p in vasp_files.valid_input_set.potcar
6869
}
69-
else:
70+
elif vasp_files.valid_input_set._pmg_vis:
7071
# Fallback, use the stats from pymatgen - only load and cache summary stats here.
71-
psp_subset = self.potcar_summary_stats.get(vasp_files.valid_input_set._config_dict["POTCAR_FUNCTIONAL"], {})
72+
psp_subset = self.potcar_summary_stats.get(vasp_files.valid_input_set.potcar_functional, {})
7273

73-
valid_potcar_summary_stats = {} # type: ignore
74+
valid_potcar_summary_stats = {}
7475
for element in vasp_files.user_input.structure.composition.remove_charges().as_dict():
75-
potcar_symbol = vasp_files.valid_input_set._config_dict["POTCAR"][element]
76+
potcar_symbol = vasp_files.valid_input_set._pmg_vis._config_dict["POTCAR"][element]
7677
for titel_no_spc in psp_subset:
7778
for psp in psp_subset[titel_no_spc]:
7879
if psp["symbol"] == potcar_symbol:
7980
if titel_no_spc not in valid_potcar_summary_stats:
8081
valid_potcar_summary_stats[titel_no_spc] = []
8182
valid_potcar_summary_stats[titel_no_spc].append(psp)
83+
else:
84+
raise ValidationError("Could not determine reference POTCARs.")
8285

8386
try:
84-
incorrect_potcars = []
85-
for potcar in vasp_files.user_input.potcar:
87+
incorrect_potcars: list[str] = []
88+
for potcar in vasp_files.user_input.potcar: # type: ignore[union-attr]
8689
reference_summary_stats = valid_potcar_summary_stats.get(potcar.titel.replace(" ", ""), [])
8790
potcar_symbol = potcar.titel.split(" ")[1]
8891

@@ -94,7 +97,7 @@ def _check_potcar_spec(self, vasp_files: VaspFiles, reasons: list[str], warnings
9497
user_summary_stats = potcar.model_dump()
9598
ref_psp = deepcopy(_ref_psp)
9699
for _set in (user_summary_stats, ref_psp):
97-
_set["keywords"]["header"] = set(_set["keywords"]["header"]).difference(self.ignore_header_keys)
100+
_set["keywords"]["header"] = set(_set["keywords"]["header"]).difference(self.ignore_header_keys) # type: ignore[arg-type]
98101
if found_match := PotcarSingle.compare_potcar_stats(
99102
ref_psp, user_summary_stats, tolerance=self.data_match_tol
100103
):

pymatgen/io/validation/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ class _PotcarSummaryStatsNames(BaseModel):
6161
lexch: str
6262

6363
@classmethod
64-
def from_file(cls, potcar: os.PathLike | Potcar) -> list[Self]:
64+
def from_file(cls, potcar_path: os.PathLike | Potcar) -> list[Self]:
6565
"""Create a list of PotcarSummaryStats from a POTCAR."""
66-
if not isinstance(potcar, Potcar):
67-
potcar = Potcar.from_file(potcar)
66+
if isinstance(potcar_path, Potcar):
67+
potcar: Potcar = potcar_path
68+
else:
69+
potcar = Potcar.from_file(str(potcar_path))
6870
return [cls(**p._summary_stats, titel=p.TITEL, lexch=p.LEXCH) for p in potcar]
6971

7072

@@ -107,12 +109,15 @@ class VaspInputSafe(BaseModel):
107109
structure: Structure = Field(description="The structure associated with the calculation.")
108110
kpoints: Kpoints | None = Field(None, description="The optional KPOINTS or IBZKPT file used in the calculation.")
109111
potcar: list[PotcarSummaryStats] | None = Field(None, description="The optional POTCAR used in the calculation.")
112+
potcar_functional: str | None = Field(None, description="The pymatgen-labelled POTCAR library release.")
110113
_pmg_vis: VaspInputSet | None = PrivateAttr(None)
111114

112115
@model_serializer
113116
def deserialize_objects(self) -> dict[str, Any]:
114117
"""Ensure all pymatgen objects are deserialized."""
115-
model_dumped: dict[str, Any] = {"potcar": [p.model_dump() for p in self.potcar]}
118+
model_dumped: dict[str, Any] = {}
119+
if self.potcar:
120+
model_dumped["potcar"] = [p.model_dump() for p in self.potcar]
116121
for k in (
117122
"incar",
118123
"structure",
@@ -134,6 +139,7 @@ def from_vasp_input_set(cls, vis: VaspInputSet) -> Self:
134139
)
135140
},
136141
potcar=PotcarSummaryStats.from_file(vis.potcar),
142+
potcar_functional=vis.potcar_functional,
137143
)
138144
new_vis._pmg_vis = vis
139145
return new_vis
@@ -194,7 +200,7 @@ def from_paths(
194200
for file_name, file_cls in to_obj.items():
195201
if (path := _vars.get(file_name)) and Path(path).exists():
196202
if file_name == "poscar":
197-
config["user_input"]["structure"] = file_cls.from_file(path).structure
203+
config["user_input"]["structure"] = Poscar.from_file(path).structure
198204
elif hasattr(file_cls, "from_file"):
199205
config["user_input"][file_name] = file_cls.from_file(path)
200206
else:
@@ -208,11 +214,11 @@ def from_paths(
208214
drift=config["outcar"].drift,
209215
magnetization=config["outcar"].magnetization,
210216
)
217+
211218
if config.get("vasprun"):
212219
config["vasprun"] = LightVasprun.from_vasprun(config["vasprun"])
213-
else:
214-
if not config["incar"].get("ENCUT") and potcar_enmax:
215-
config["incar"]["ENCUT"] = potcar_enmax
220+
elif not config["user_input"]["incar"].get("ENCUT") and potcar_enmax:
221+
config["user_input"]["incar"]["ENCUT"] = potcar_enmax
216222

217223
return cls(**config)
218224

pymatgen/io/validation/compare_to_MP_ehull.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Module for checking if a structure's energy is within a certain distance of the MPDB hull"""
22

3-
from mp_api.client import MPRester
3+
from mp_api.client import MPRester # type: ignore[import-untyped]
44
from pymatgen.analysis.phase_diagram import PhaseDiagram
55
from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme
66
from pymatgen.entries.computed_entries import ComputedStructureEntry

pymatgen/io/validation/emmet_validation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from datetime import datetime
66
from pydantic import Field
77

8-
from emmet.core.tasks import TaskDoc
9-
from emmet.core.vasp.calculation import Calculation
10-
from emmet.core.vasp.task_valid import TaskDocument
11-
from emmet.core.base import EmmetBaseModel
12-
from emmet.core.mpid import MPID
13-
from emmet.core.utils import utcnow
8+
from emmet.core.tasks import TaskDoc # type: ignore[import-untyped]
9+
from emmet.core.vasp.calculation import Calculation # type: ignore[import-untyped]
10+
from emmet.core.vasp.task_valid import TaskDocument # type: ignore[import-untyped]
11+
from emmet.core.base import EmmetBaseModel # type: ignore[import-untyped]
12+
from emmet.core.mpid import MPID # type: ignore[import-untyped]
13+
from emmet.core.utils import utcnow # type: ignore[import-untyped]
1414

1515
from pymatgen.io.vasp import Incar
1616

0 commit comments

Comments
 (0)