Skip to content

Commit 5d6f566

Browse files
AseAtomsAdaptor: Retain tags property when interconverting Atoms and Structure/Molecule (#3151)
* add support for tags
1 parent 9122d21 commit 5d6f566

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

pymatgen/core/tests/test_structure.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,18 +1420,13 @@ def test_calculate_chgnet(self):
14201420
assert preds["magmoms"] == approx([0.00262399, 0.00262396], abs=1e-5)
14211421
assert np.linalg.norm(preds["forces"]) == approx(1.998941843e-5, abs=1e-3)
14221422
assert not hasattr(calculator, "dynamics"), "static calculation should not have dynamics"
1423-
assert {*calculator.__dict__} == {
1424-
"atoms",
1425-
"results",
1426-
"parameters",
1427-
"_directory",
1428-
"prefix",
1429-
"name",
1430-
"get_spin_polarized",
1431-
"device",
1432-
"model",
1433-
"stress_weight",
1434-
}
1423+
assert "atoms" in calculator.__dict__
1424+
assert "results" in calculator.__dict__
1425+
assert "parameters" in calculator.__dict__
1426+
assert "get_spin_polarized" in calculator.__dict__
1427+
assert "device" in calculator.__dict__
1428+
assert "model" in calculator.__dict__
1429+
assert "stress_weight" in calculator.__dict__
14351430
assert len(calculator.parameters) == 0
14361431
assert isinstance(calculator.atoms, Atoms)
14371432
assert len(calculator.atoms) == len(struct)

pymatgen/io/ase.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:
7272

7373
atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs)
7474

75+
if "tags" in structure.site_properties:
76+
atoms.set_tags(structure.site_properties["tags"])
77+
7578
# Set the site magmoms in the ASE Atoms object
7679
# Note: ASE distinguishes between initial and converged
7780
# magnetic moment site properties, whereas pymatgen does not. Therefore, we
@@ -181,6 +184,9 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
181184
positions = atoms.get_positions()
182185
lattice = atoms.get_cell()
183186

187+
# Get the tags
188+
tags = atoms.get_tags() if atoms.has("tags") else None
189+
184190
# Get the (final) site magmoms and charges from the ASE Atoms object.
185191
if getattr(atoms, "calc", None) is not None and getattr(atoms.calc, "results", None) is not None:
186192
charges = atoms.calc.results.get("charges")
@@ -247,6 +253,8 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
247253
structure.add_site_property("magmom", initial_magmoms)
248254
if sel_dyn is not None and ~np.all(sel_dyn):
249255
structure.add_site_property("selective_dynamics", sel_dyn)
256+
if tags is not None:
257+
structure.add_site_property("tags", tags)
250258

251259
# Add oxidation states by site
252260
if oxi_states is not None:

pymatgen/io/tests/test_ase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_get_atoms_from_structure(self):
2727
assert atoms.get_pbc().all()
2828
assert atoms.get_chemical_symbols() == [s.species_string for s in structure]
2929
assert not atoms.has("initial_magmoms")
30+
assert not atoms.has("initial_charges")
3031
assert atoms.calc is None
3132

3233
p = Poscar.from_file(os.path.join(PymatgenTest.TEST_FILES_DIR, "POSCAR"))
@@ -248,6 +249,7 @@ def test_back_forth(self):
248249
# Atoms --> Structure --> Atoms --> Structure
249250
atoms = read(os.path.join(PymatgenTest.TEST_FILES_DIR, "OUTCAR"))
250251
atoms.info = {"test": "hi"}
252+
atoms.set_tags([1] * len(atoms))
251253
atoms.set_constraint(FixAtoms(mask=[True] * len(atoms)))
252254
atoms.set_initial_charges([1.0] * len(atoms))
253255
atoms.set_initial_magnetic_moments([2.0] * len(atoms))
@@ -281,6 +283,7 @@ def test_back_forth(self):
281283
atoms.set_initial_charges([1.0] * len(atoms))
282284
atoms.set_initial_magnetic_moments([2.0] * len(atoms))
283285
atoms.set_array("prop", np.array([3.0] * len(atoms)))
286+
atoms.set_tags([1] * len(atoms))
284287
molecule = aio.AseAtomsAdaptor.get_molecule(atoms)
285288
atoms_back = aio.AseAtomsAdaptor.get_atoms(molecule)
286289
molecule_back = aio.AseAtomsAdaptor.get_molecule(atoms_back)

0 commit comments

Comments
 (0)