Skip to content

Commit db20521

Browse files
Allow for writing of Structure.site_properties as _atom_site_ flags in CifWriter (#3550)
* fix mypy, fix ruff, tweak test_cif_writer_site_properties --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 3a754c1 commit db20521

File tree

5 files changed

+52
-29
lines changed

5 files changed

+52
-29
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ ci:
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.1.11
11+
rev: v0.1.13
1212
hooks:
1313
- id: ruff
1414
args: [--fix, --unsafe-fixes]

pymatgen/io/abinit/abitimer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs):
888888
# axHistx.axis["bottom"].major_ticklabels.set_visible(False)
889889
axHistx.set_yticks([0, 50, 100])
890890
for tl in axHistx.get_xticklabels():
891-
tl.set_visible(False) # noqa: FBT003
891+
tl.set_visible(False)
892892

893893
# axHisty.axis["left"].major_ticklabels.set_visible(False)
894894
for tl in axHisty.get_yticklabels():
895-
tl.set_visible(False) # noqa: FBT003
895+
tl.set_visible(False)
896896
axHisty.set_xticks([0, 50, 100])
897897

898898
# plt.draw()

pymatgen/io/abinit/netcdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(self, path):
9191
# Slicing a ncvar returns a MaskedArrray and this is really annoying
9292
# because it can lead to unexpected behavior in e.g. calls to np.matmul!
9393
# See also https://github.com/Unidata/netcdf4-python/issues/785
94-
self.rootgrp.set_auto_mask(False) # noqa: FBT003
94+
self.rootgrp.set_auto_mask(False)
9595

9696
def __enter__(self):
9797
"""Activated when used in the with statement."""

pymatgen/io/cif.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import re
88
import textwrap
99
import warnings
10-
from collections import deque
10+
from collections import defaultdict, deque
1111
from datetime import datetime
1212
from functools import partial
1313
from inspect import getfullargspec as getargspec
@@ -1313,13 +1313,14 @@ class CifWriter:
13131313

13141314
def __init__(
13151315
self,
1316-
struct,
1317-
symprec=None,
1318-
write_magmoms=False,
1319-
significant_figures=8,
1320-
angle_tolerance=5.0,
1321-
refine_struct=True,
1322-
):
1316+
struct: Structure,
1317+
symprec: float | None = None,
1318+
write_magmoms: bool = False,
1319+
significant_figures: int = 8,
1320+
angle_tolerance: float = 5,
1321+
refine_struct: bool = True,
1322+
write_site_properties: bool = False,
1323+
) -> None:
13231324
"""
13241325
Args:
13251326
struct (Structure): structure to write
@@ -1335,14 +1336,16 @@ def __init__(
13351336
is not None.
13361337
refine_struct: Used only if symprec is not None. If True, get_refined_structure
13371338
is invoked to convert input structure from primitive to conventional.
1339+
write_site_properties (bool): Whether to write the Structure.site_properties
1340+
to the CIF as _atom_site_{property name}. Defaults to False.
13381341
"""
13391342
if write_magmoms and symprec:
13401343
warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.")
13411344
symprec = None
13421345

13431346
format_str = f"{{:.{significant_figures}f}}"
13441347

1345-
block = {}
1348+
block: dict[str, Any] = {}
13461349
loops = []
13471350
spacegroup = ("P 1", 1)
13481351
if symprec is not None:
@@ -1367,7 +1370,7 @@ def __init__(
13671370
block["_chemical_formula_sum"] = no_oxi_comp.formula
13681371
block["_cell_volume"] = format_str.format(lattice.volume)
13691372

1370-
_reduced_comp, fu = no_oxi_comp.get_reduced_composition_and_factor()
1373+
_, fu = no_oxi_comp.get_reduced_composition_and_factor()
13711374
block["_cell_formula_units_Z"] = str(int(fu))
13721375

13731376
if symprec is None:
@@ -1388,12 +1391,12 @@ def __init__(
13881391
loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"])
13891392

13901393
try:
1391-
symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)}
1392-
block["_atom_type_symbol"] = list(symbol_to_oxinum)
1393-
block["_atom_type_oxidation_number"] = symbol_to_oxinum.values()
1394+
symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)}
1395+
block["_atom_type_symbol"] = list(symbol_to_oxi_num)
1396+
block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values()
13941397
loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"])
13951398
except (TypeError, AttributeError):
1396-
symbol_to_oxinum = {el.symbol: 0 for el in sorted(comp.elements)}
1399+
symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)}
13971400

13981401
atom_site_type_symbol = []
13991402
atom_site_symmetry_multiplicity = []
@@ -1406,6 +1409,7 @@ def __init__(
14061409
atom_site_moment_crystalaxis_x = []
14071410
atom_site_moment_crystalaxis_y = []
14081411
atom_site_moment_crystalaxis_z = []
1412+
atom_site_properties: dict[str, list] = defaultdict(list)
14091413
count = 0
14101414
if symprec is None:
14111415
for site in struct:
@@ -1437,6 +1441,10 @@ def __init__(
14371441
atom_site_moment_crystalaxis_y.append(format_str.format(moment[1]))
14381442
atom_site_moment_crystalaxis_z.append(format_str.format(moment[2]))
14391443

1444+
if write_site_properties:
1445+
for key, val in site.properties.items():
1446+
atom_site_properties[key].append(format_str.format(val))
1447+
14401448
count += 1
14411449
else:
14421450
# The following just presents a deterministic ordering.
@@ -1475,17 +1483,21 @@ def __init__(
14751483
block["_atom_site_fract_y"] = atom_site_fract_y
14761484
block["_atom_site_fract_z"] = atom_site_fract_z
14771485
block["_atom_site_occupancy"] = atom_site_occupancy
1478-
loops.append(
1479-
[
1480-
"_atom_site_type_symbol",
1481-
"_atom_site_label",
1482-
"_atom_site_symmetry_multiplicity",
1483-
"_atom_site_fract_x",
1484-
"_atom_site_fract_y",
1485-
"_atom_site_fract_z",
1486-
"_atom_site_occupancy",
1487-
]
1488-
)
1486+
loop_labels = [
1487+
"_atom_site_type_symbol",
1488+
"_atom_site_label",
1489+
"_atom_site_symmetry_multiplicity",
1490+
"_atom_site_fract_x",
1491+
"_atom_site_fract_y",
1492+
"_atom_site_fract_z",
1493+
"_atom_site_occupancy",
1494+
]
1495+
if write_site_properties:
1496+
for key, vals in atom_site_properties.items():
1497+
block[f"_atom_site_{key}"] = vals
1498+
loop_labels += [f"_atom_site_{key}"]
1499+
loops.append(loop_labels)
1500+
14891501
if write_magmoms:
14901502
block["_atom_site_moment_label"] = atom_site_moment_label
14911503
block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x

tests/io/test_cif.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,17 @@ def test_cif_writer_write_file(self):
870870
assert len(read_structs) == 2
871871
assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"]
872872

873+
def test_cif_writer_site_properties(self):
874+
struct = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")
875+
struct.add_site_property(label := "hello", [1.0] * (len(struct) - 1) + [-1.0])
876+
out_path = f"{self.tmp_path}/test2.cif"
877+
CifWriter(struct, write_site_properties=True).write_file(out_path)
878+
with open(out_path) as file:
879+
cif_str = file.read()
880+
assert f"_atom_site_occupancy\n _atom_site_{label}\n" in cif_str
881+
assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str
882+
assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str
883+
873884

874885
class TestMagCif(PymatgenTest):
875886
def setUp(self):

0 commit comments

Comments
 (0)