Skip to content

Commit 8ba4b6e

Browse files
authored
Handle numpy array for selective dynamics in Structure (#4461)
* enable orjson.OPT_SERIALIZE_NUMPY * np.array -> np.asarray * remove tolist * safer deprecation replacement * rename * better deprecation * put structure related ops together * add test * str -> PathLike for filename * clean up * also enable OPT_SERIALIZE_NUMPY for molecule * overwrite default grou * ignore arg-type, fix needed from monty * Revert "overwrite default grou" This reverts commit 409d954.
1 parent 4201994 commit 8ba4b6e

File tree

6 files changed

+68
-28
lines changed

6 files changed

+68
-28
lines changed

src/pymatgen/core/structure.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def is_valid(self, tol: float = DISTANCE_TOLERANCE) -> bool:
520520
return bool(np.min(all_dists) > tol)
521521

522522
@abstractmethod
523-
def to(self, filename: str = "", fmt: FileFormats = "") -> str | None:
523+
def to(self, filename: PathLike = "", fmt: FileFormats = "") -> str | None:
524524
"""Generate string representations (cif, json, poscar, ....) of SiteCollections (e.g.,
525525
molecules / structures). Should return str or None if written to a file.
526526
"""
@@ -564,6 +564,7 @@ def add_site_property(self, property_name: str, values: Sequence | NDArray) -> S
564564
"""
565565
if len(values) != len(self):
566566
raise ValueError(f"{len(values)=} must equal sites in structure={len(self)}")
567+
567568
for site, val in zip(self, values, strict=True):
568569
site.properties[property_name] = val
569570

@@ -2964,7 +2965,11 @@ def to(self, filename: PathLike = "", fmt: FileFormats = "", **kwargs) -> str:
29642965
writer = Cssr(self)
29652966

29662967
elif fmt == "json" or fnmatch(filename.lower(), "*.json*"):
2967-
json_str = json.dumps(self.as_dict(), **kwargs) if kwargs else orjson.dumps(self.as_dict()).decode()
2968+
json_str = (
2969+
json.dumps(self.as_dict(), **kwargs)
2970+
if kwargs
2971+
else orjson.dumps(self.as_dict(), option=orjson.OPT_SERIALIZE_NUMPY).decode()
2972+
)
29682973

29692974
if filename:
29702975
with zopen(filename, mode="wt", encoding="utf-8") as file:
@@ -3989,11 +3994,11 @@ def get_centered_molecule(self) -> Self:
39893994
properties=self.properties,
39903995
)
39913996

3992-
def to(self, filename: str = "", fmt: str = "") -> str | None:
3997+
def to(self, filename: PathLike = "", fmt: str = "") -> str | None:
39933998
"""Outputs the molecule to a file or string.
39943999
39954000
Args:
3996-
filename (str): If provided, output will be written to a file. If
4001+
filename (PathLike): If provided, output will be written to a file. If
39974002
fmt is not specified, the format is determined from the
39984003
filename. Defaults is None, i.e. string output.
39994004
fmt (str): Format to output to. Defaults to JSON unless filename
@@ -4006,22 +4011,28 @@ def to(self, filename: str = "", fmt: str = "") -> str | None:
40064011
str: String representation of molecule in given format. If a filename
40074012
is provided, the same string is written to the file.
40084013
"""
4014+
filename = str(filename)
40094015
fmt = fmt.lower()
4016+
40104017
writer: Any
40114018
if fmt == "xyz" or fnmatch(filename.lower(), "*.xyz*"):
40124019
from pymatgen.io.xyz import XYZ
40134020

40144021
writer = XYZ(self)
4022+
40154023
elif any(fmt == ext or fnmatch(filename.lower(), f"*.{ext}*") for ext in ("gjf", "g03", "g09", "com", "inp")):
40164024
from pymatgen.io.gaussian import GaussianInput
40174025

40184026
writer = GaussianInput(self)
4027+
40194028
elif fmt == "json" or fnmatch(filename, "*.json*") or fnmatch(filename, "*.mson*"):
4020-
json_str = orjson.dumps(self.as_dict()).decode()
4029+
json_str = orjson.dumps(self.as_dict(), option=orjson.OPT_SERIALIZE_NUMPY).decode()
4030+
40214031
if filename:
40224032
with zopen(filename, mode="wt", encoding="utf-8") as file:
40234033
file.write(json_str) # type:ignore[arg-type]
40244034
return json_str
4035+
40254036
elif fmt in {"yaml", "yml"} or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"):
40264037
yaml = YAML()
40274038
str_io = io.StringIO()
@@ -4031,6 +4042,7 @@ def to(self, filename: str = "", fmt: str = "") -> str | None:
40314042
with zopen(filename, mode="wt", encoding="utf-8") as file:
40324043
file.write(yaml_str) # type:ignore[arg-type]
40334044
return yaml_str
4045+
40344046
else:
40354047
from pymatgen.io.babel import BabelMolAdaptor
40364048

@@ -4042,6 +4054,7 @@ def to(self, filename: str = "", fmt: str = "") -> str | None:
40424054

40434055
if filename:
40444056
writer.write_file(filename)
4057+
40454058
return str(writer)
40464059

40474060
@classmethod
@@ -4109,28 +4122,35 @@ def from_file(cls, filename: PathLike) -> IMolecule | Molecule: # type:ignore[o
41094122
Molecule
41104123
"""
41114124
filename = str(filename)
4125+
fname = filename.lower()
41124126

41134127
with zopen(filename, mode="rt", encoding="utf-8") as file:
4114-
contents: str = file.read() # type:ignore[assignment]
4115-
fname = filename.lower()
4128+
contents: str = cast("str", file.read())
4129+
41164130
if fnmatch(fname, "*.xyz*"):
41174131
return cls.from_str(contents, fmt="xyz")
4132+
41184133
if any(fnmatch(fname.lower(), f"*.{r}*") for r in ("gjf", "g03", "g09", "com", "inp")):
41194134
return cls.from_str(contents, fmt="g09")
4135+
41204136
if any(fnmatch(fname.lower(), f"*.{r}*") for r in ("out", "lis", "log")):
41214137
from pymatgen.io.gaussian import GaussianOutput
41224138

41234139
return GaussianOutput(filename).final_structure
4140+
41244141
if fnmatch(fname, "*.json*") or fnmatch(fname, "*.mson*"):
41254142
return cls.from_str(contents, fmt="json")
4143+
41264144
if fnmatch(fname, "*.yaml*") or fnmatch(filename, "*.yml*"):
41274145
return cls.from_str(contents, fmt="yaml")
4128-
from pymatgen.io.babel import BabelMolAdaptor
41294146

41304147
if match := re.search(r"\.(pdb|mol|mdl|sdf|sd|ml2|sy2|mol2|cml|mrv)", filename.lower()):
4148+
from pymatgen.io.babel import BabelMolAdaptor
4149+
41314150
new = BabelMolAdaptor.from_file(filename, match[1]).pymatgen_mol
41324151
new.__class__ = cls
41334152
return new
4153+
41344154
raise ValueError("Cannot determine file type.")
41354155

41364156

src/pymatgen/io/abinit/abitimer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,8 @@ def as_dict(self):
628628
return {at: self.__dict__[at] for at in AbinitTimerSection.FIELDS}
629629

630630
@deprecated(as_dict, deadline=(2026, 4, 4))
631-
def to_dict(self):
632-
return self.as_dict()
631+
def to_dict(self, *args, **kwargs):
632+
return self.as_dict(*args, **kwargs)
633633

634634
def to_csvline(self, with_header=False):
635635
"""Return a string with data in CSV format. Add header if `with_header`."""

src/pymatgen/io/aims/inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -935,8 +935,8 @@ def as_dict(self):
935935
}
936936

937937
@deprecated(replacement=as_dict, deadline=(2026, 4, 4))
938-
def to_dict(self):
939-
return self.as_dict()
938+
def to_dict(self, *args, **kwargs):
939+
return self.as_dict(*args, **kwargs)
940940

941941
@classmethod
942942
def from_dict(cls, dct: dict[str, Any]) -> SpeciesDefaults:

src/pymatgen/io/vasp/inputs.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import numpy as np
2626
import orjson
2727
import scipy.constants as const
28+
from monty.dev import deprecated
2829
from monty.io import zopen
2930
from monty.json import MontyDecoder, MSONable
3031
from monty.os import cd
@@ -141,41 +142,41 @@ def __init__(
141142
site_properties: dict[str, Any] = {}
142143

143144
if selective_dynamics is not None:
144-
selective_dynamics = np.array(selective_dynamics)
145+
selective_dynamics = np.asarray(selective_dynamics)
145146
if not selective_dynamics.all():
146147
site_properties["selective_dynamics"] = selective_dynamics
147148

148149
if velocities:
149-
velocities = np.array(velocities)
150+
velocities = np.asarray(velocities)
150151
if velocities.any():
151152
site_properties["velocities"] = velocities
152153

153154
if predictor_corrector:
154-
predictor_corrector = np.array(predictor_corrector)
155+
predictor_corrector = np.asarray(predictor_corrector)
155156
if predictor_corrector.any():
156157
site_properties["predictor_corrector"] = predictor_corrector
157158

158159
structure = Structure.from_sites(structure)
159160
self.structure = structure.copy(site_properties=site_properties)
160161
if sort_structure:
161162
self.structure = self.structure.get_sorted_structure()
162-
self.true_names = true_names
163-
self.comment = structure.formula if comment is None else comment
163+
164164
if predictor_corrector_preamble:
165165
self.structure.properties["predictor_corrector_preamble"] = predictor_corrector_preamble
166166

167167
if lattice_velocities and np.any(lattice_velocities):
168168
self.structure.properties["lattice_velocities"] = np.asarray(lattice_velocities)
169169

170-
self.temperature = -1.0
170+
self.true_names: bool = true_names
171+
self.comment: str = structure.formula if comment is None else comment
172+
self.temperature: float = -1.0
171173

172174
def __setattr__(self, name: str, value: Any) -> None:
173175
if name in {"selective_dynamics", "velocities"} and value is not None and len(value) > 0:
174-
value = np.array(value)
176+
value = np.asarray(value)
175177
dim = value.shape
176178
if dim[1] != 3 or dim[0] != len(self.structure):
177179
raise ValueError(f"{name} array must be same length as the structure.")
178-
value = value.tolist()
179180

180181
super().__setattr__(name, value)
181182

@@ -687,7 +688,7 @@ def get_str(
687688
lines.append("")
688689
if self.predictor_corrector_preamble:
689690
lines.append(self.predictor_corrector_preamble)
690-
pred = np.array(self.predictor_corrector)
691+
pred = np.asarray(self.predictor_corrector)
691692
for col in range(3):
692693
for z in pred[:, col]:
693694
lines.append(" ".join(format_str.format(i) for i in z))
@@ -700,7 +701,9 @@ def get_str(
700701

701702
return "\n".join(lines) + "\n"
702703

703-
get_string = get_str
704+
@deprecated(get_str)
705+
def get_string(self, *args, **kwargs):
706+
return self.get_str(*args, **kwargs)
704707

705708
def write_file(self, filename: PathLike, **kwargs) -> None:
706709
"""Write POSCAR to a file. The supported kwargs are the same as those for
@@ -716,7 +719,7 @@ def as_dict(self) -> dict:
716719
"@class": type(self).__name__,
717720
"structure": self.structure.as_dict(),
718721
"true_names": self.true_names,
719-
"selective_dynamics": np.array(self.selective_dynamics).tolist(),
722+
"selective_dynamics": np.asarray(self.selective_dynamics).tolist(),
720723
"velocities": self.velocities,
721724
"predictor_corrector": self.predictor_corrector,
722725
"comment": self.comment,
@@ -2432,7 +2435,7 @@ def parse_fortran_style_str(input_str: str) -> str | bool | float | int:
24322435

24332436
def data_stats(data_list: Sequence) -> dict:
24342437
"""Used for hash-less and therefore less brittle POTCAR validity checking."""
2435-
arr = np.array(data_list)
2438+
arr = np.asarray(data_list)
24362439
return {
24372440
"MEAN": np.mean(arr),
24382441
"ABSMEAN": np.mean(np.abs(arr)),
@@ -2622,7 +2625,7 @@ def compare_potcar_stats(
26222625
if key_match:
26232626
data_diff = [
26242627
abs(potcar_stats_1["stats"].get(key, {}).get(stat) - potcar_stats_2["stats"].get(key, {}).get(stat))
2625-
for stat in ["MEAN", "ABSMEAN", "VAR", "MIN", "MAX"]
2628+
for stat in ("MEAN", "ABSMEAN", "VAR", "MIN", "MAX")
26262629
for key in check_potcar_fields
26272630
]
26282631
data_match = all(np.array(data_diff) < tolerance)

src/pymatgen/io/vasp/sets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,13 +488,13 @@ def get_input_set(
488488
)
489489

490490
@deprecated(get_input_set, deadline=(2026, 6, 6))
491-
def get_vasp_input(self, structure: Structure | None = None) -> VaspInput:
491+
def get_vasp_input(self, *args, **kwargs) -> VaspInput:
492492
"""Get a VaspInput object.
493493
494494
Returns:
495495
VaspInput.
496496
"""
497-
return self.get_input_set(structure=structure)
497+
return self.get_input_set(*args, **kwargs)
498498

499499
@property
500500
def incar_updates(self) -> dict:

tests/core/test_structure.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,7 @@ def test_add_remove_site_property(self):
12521252
with pytest.raises(AttributeError, match="attr='magmom' not found on PeriodicSite"):
12531253
_ = struct[0].magmom
12541254

1255-
def test_propertied_structure(self):
1255+
def test_site_properties(self):
12561256
# Make sure that site properties are set to None for missing values.
12571257
self.struct.add_site_property("charge", [4.1, -5])
12581258
self.struct.append("Li", [0.3, 0.3, 0.3])
@@ -1278,6 +1278,23 @@ def test_propertied_structure(self):
12781278
assert struct.properties == props
12791279
assert dct == struct.as_dict()
12801280

1281+
def test_selective_dynamics(self):
1282+
"""Ensure selective dynamics as numpy arrays can be JSON serialized."""
1283+
struct = self.get_structure("Li2O")
1284+
struct.add_site_property(
1285+
"selective_dynamics", np.array([[True, True, True], [False, False, False], [True, True, True]])
1286+
)
1287+
1288+
orjson_str = struct.to(fmt="json")
1289+
1290+
# Also test round trip
1291+
orjson_struct = Structure.from_str(orjson_str, fmt="json")
1292+
assert struct == orjson_struct
1293+
1294+
with pytest.raises(TypeError, match="Object of type ndarray is not JSON serializable"):
1295+
# Use a dummy kwarg (default value) to force `json.dumps`
1296+
struct.to(fmt="json", ensure_ascii=True)
1297+
12811298
def test_perturb(self):
12821299
struct = self.get_structure("Li2O")
12831300
struct_orig = struct.copy()

0 commit comments

Comments
 (0)