Skip to content

Commit c14d67e

Browse files
authored
Fix ValueError: Invalid fmt with Structure.to(fmt='yml') (#3557)
* Fix Structure.to(fmt='yml'); add Structure.FileFormats to ensure consistent format support and ValueError messages in Structure.to and Structure.from_file
* test_structure.py: cover fmt='yml' in test_to_from_file_string and fix the message expected by pytest.raises
1 parent 88921da commit c14d67e

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

pymatgen/core/structure.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060

6161
from pymatgen.util.typing import CompositionLike, SpeciesLike
6262

63+
FileFormats = Literal["cif", "poscar", "cssr", "json", "yaml", "yml", "xsf", "mcsqs", "res", ""]
64+
6365

6466
class Neighbor(Site):
6567
"""Simple Site subclass to contain a neighboring atom that skips all the unnecessary checks for speed. Can be
@@ -447,13 +449,13 @@ def is_valid(self, tol: float = DISTANCE_TOLERANCE) -> bool:
447449
return np.min(all_dists) > tol
448450

449451
@abstractmethod
450-
def to(self, filename: str = "", fmt: str = "") -> str | None:
452+
def to(self, filename: str = "", fmt: FileFormats = "") -> str | None:
451453
"""Generates string representations (cif, json, poscar, ....) of SiteCollections (e.g.,
452454
molecules / structures). Should return str or None if written to a file.
453455
"""
454456
raise NotImplementedError
455457

456-
def to_file(self, filename: str = "", fmt: str = "") -> str | None:
458+
def to_file(self, filename: str = "", fmt: FileFormats = "") -> str | None:
457459
"""A more intuitive alias for .to()."""
458460
return self.to(filename, fmt)
459461

@@ -2653,7 +2655,7 @@ def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) -
26532655
charge = dct.get("charge")
26542656
return cls.from_sites(sites, charge=charge, properties=dct.get("properties"))
26552657

2656-
def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
2658+
def to(self, filename: str | Path = "", fmt: FileFormats = "", **kwargs) -> str:
26572659
"""Outputs the structure to a file or string.
26582660
26592661
Args:
@@ -2663,7 +2665,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
26632665
fmt (str): Format to output to. Defaults to JSON unless filename
26642666
is provided. If fmt is specified, it overrides whatever the
26652667
filename is. Options include "cif", "poscar", "cssr", "json",
2666-
"xsf", "mcsqs", "prismatic", "yaml", "fleur-inpgen".
2668+
"xsf", "mcsqs", "prismatic", "yaml", "yml", "fleur-inpgen".
26672669
Non-case sensitive.
26682670
**kwargs: Kwargs passthru to relevant methods. E.g., This allows
26692671
the passing of parameters like symprec to the
@@ -2673,7 +2675,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
26732675
str: String representation of molecule in given format. If a filename
26742676
is provided, the same string is written to the file.
26752677
"""
2676-
filename, fmt = str(filename), fmt.lower()
2678+
filename, fmt = str(filename), cast(FileFormats, fmt.lower())
26772679

26782680
if fmt == "cif" or fnmatch(filename.lower(), "*.cif*"):
26792681
from pymatgen.io.cif import CifWriter
@@ -2722,7 +2724,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
27222724
from pymatgen.io.prismatic import Prismatic
27232725

27242726
return Prismatic(self).to_str()
2725-
elif fmt == "yaml" or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"):
2727+
elif fmt in ("yaml", "yml") or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"):
27262728
yaml = YAML()
27272729
str_io = StringIO()
27282730
yaml.dump(self.as_dict(), str_io)
@@ -2747,7 +2749,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
27472749
else:
27482750
if fmt == "":
27492751
raise ValueError(f"Format not specified and could not infer from {filename=}")
2750-
raise ValueError(f"Invalid format={fmt!r}")
2752+
raise ValueError(f"Invalid {fmt=}, valid options are {get_args(FileFormats)}")
27512753

27522754
if filename:
27532755
writer.write_file(filename)
@@ -2757,7 +2759,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
27572759
def from_str( # type: ignore[override]
27582760
cls,
27592761
input_string: str,
2760-
fmt: Literal["cif", "poscar", "cssr", "json", "yaml", "xsf", "mcsqs", "res"],
2762+
fmt: FileFormats,
27612763
primitive: bool = False,
27622764
sort: bool = False,
27632765
merge_tol: float = 0.0,
@@ -2768,7 +2770,7 @@ def from_str( # type: ignore[override]
27682770
Args:
27692771
input_string (str): String to parse.
27702772
fmt (str): A file format specification. One of "cif", "poscar", "cssr",
2771-
"json", "yaml", "xsf", "mcsqs".
2773+
"json", "yaml", "yml", "xsf", "mcsqs", "res".
27722774
primitive (bool): Whether to find a primitive cell. Defaults to
27732775
False.
27742776
sort (bool): Whether to sort the sites in accordance to the default
@@ -2797,12 +2799,12 @@ def from_str( # type: ignore[override]
27972799
cssr = Cssr.from_str(input_string, **kwargs)
27982800
struct = cssr.structure
27992801
elif fmt_low == "json":
2800-
d = json.loads(input_string)
2801-
struct = Structure.from_dict(d)
2802-
elif fmt_low == "yaml":
2802+
dct = json.loads(input_string)
2803+
struct = Structure.from_dict(dct)
2804+
elif fmt_low in ("yaml", "yml"):
28032805
yaml = YAML()
2804-
d = yaml.load(input_string)
2805-
struct = Structure.from_dict(d)
2806+
dct = yaml.load(input_string)
2807+
struct = Structure.from_dict(dct)
28062808
elif fmt_low == "xsf":
28072809
from pymatgen.io.xcrysden import XSF
28082810

@@ -2825,7 +2827,7 @@ def from_str( # type: ignore[override]
28252827

28262828
struct = ResIO.structure_from_str(input_string, **kwargs)
28272829
else:
2828-
raise ValueError(f"Unrecognized format `{fmt}`!")
2830+
raise ValueError(f"Invalid {fmt=}, valid options are {get_args(FileFormats)}")
28292831

28302832
if sort:
28312833
struct = struct.get_sorted_structure()

tests/core/test_structure.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def test_get_dist_matrix(self):
818818
assert_allclose(self.struct.distance_matrix, ans)
819819

820820
def test_to_from_file_and_string(self):
821-
for fmt in ["cif", "json", "poscar", "cssr"]:
821+
for fmt in ("cif", "json", "poscar", "cssr"):
822822
struct = self.struct.to(fmt=fmt)
823823
assert struct is not None
824824
ss = IStructure.from_str(struct, fmt=fmt)
@@ -851,7 +851,7 @@ def test_to_from_file_and_string(self):
851851

852852
with pytest.raises(ValueError, match="Format not specified and could not infer from filename='whatever'"):
853853
self.struct.to(filename="whatever")
854-
with pytest.raises(ValueError, match="Invalid format='badformat'"):
854+
with pytest.raises(ValueError, match="Invalid fmt='badformat'"):
855855
self.struct.to(fmt="badformat")
856856

857857
self.struct.to(filename=(gz_json_path := "POSCAR.testing.gz"))
@@ -1284,7 +1284,7 @@ def test_to_from_abivars(self):
12841284

12851285
def test_to_from_file_string(self):
12861286
# to/from string
1287-
for fmt in ["cif", "json", "poscar", "cssr", "yaml", "xsf", "res"]:
1287+
for fmt in ("cif", "json", "poscar", "cssr", "yaml", "yml", "xsf", "res"):
12881288
struct = self.struct.to(fmt=fmt)
12891289
assert struct is not None
12901290
ss = Structure.from_str(struct, fmt=fmt)

0 commit comments

Comments
 (0)