|
8 | 8 | import os |
9 | 9 | import numpy as np |
10 | 10 | from pathlib import Path |
11 | | -from pydantic import BaseModel, Field, model_serializer, PrivateAttr |
| 11 | +from pydantic import BaseModel, Field, model_validator, model_serializer, PrivateAttr |
12 | 12 | from typing import TYPE_CHECKING, Any, Optional |
13 | 13 |
|
14 | 14 | from pymatgen.core import Structure |
@@ -199,6 +199,20 @@ class VaspFiles(BaseModel): |
199 | 199 | outcar: Optional[LightOutcar] = None |
200 | 200 | vasprun: Optional[LightVasprun] = None |
201 | 201 |
|
| 202 | + @model_validator(mode="before") |
| 203 | + @classmethod |
| 204 | + def coerce_to_lightweight(cls, config: Any) -> Any: |
| 205 | + """Ensure that pymatgen objects are converted to minimal representations.""" |
| 206 | + if isinstance(config.get("outcar"), Outcar): |
| 207 | + config["outcar"] = LightOutcar( |
| 208 | + drift=config["outcar"].drift, |
| 209 | + magnetization=config["outcar"].magnetization, |
| 210 | + ) |
| 211 | + |
| 212 | + if isinstance(config.get("vasprun"), Vasprun): |
| 213 | + config["vasprun"] = LightVasprun.from_vasprun(config["vasprun"]) |
| 214 | + return config |
| 215 | + |
202 | 216 | @property |
203 | 217 | def md5(self) -> str: |
204 | 218 | """Get MD5 of VaspFiles for use in validation checks.""" |
@@ -256,15 +270,7 @@ def from_paths( |
256 | 270 | if file_name == "potcar": |
257 | 271 | potcar_enmax = max(ps.ENMAX for ps in Potcar.from_file(path)) |
258 | 272 |
|
259 | | - if config.get("outcar"): |
260 | | - config["outcar"] = LightOutcar( |
261 | | - drift=config["outcar"].drift, |
262 | | - magnetization=config["outcar"].magnetization, |
263 | | - ) |
264 | | - |
265 | | - if config.get("vasprun"): |
266 | | - config["vasprun"] = LightVasprun.from_vasprun(config["vasprun"]) |
267 | | - elif not config["user_input"]["incar"].get("ENCUT") and potcar_enmax: |
| 273 | + if not config.get("vasprun") and not config["user_input"]["incar"].get("ENCUT") and potcar_enmax: |
268 | 274 | config["user_input"]["incar"]["ENCUT"] = potcar_enmax |
269 | 275 |
|
270 | 276 | return cls(**config) |
|
0 commit comments