Skip to content

Commit f1ac4aa

Browse files
stefsmeetsjanosh
andauthored
Propagate labels through various Structure operations (#3183)
* Fix for multiple species with different names * Propagate site labels in structure operations * use pymatgen.core.trajectory.Vector3D where applicable * remove _in suffix from _unique_coords() args * Propagate labels in more methods * Fix type hints * Write labels back to cif * Remove assert statement * Fix typo * Change single character variable * fix type hint for labels: list | None -> labels: Sequence[str | None] | None --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent ae161fe commit f1ac4aa

File tree

9 files changed

+167
-73
lines changed

9 files changed

+167
-73
lines changed

pymatgen/analysis/gb/grain.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
if TYPE_CHECKING:
2222
from numpy.typing import ArrayLike
2323

24+
from pymatgen.core.trajectory import Vector3D
2425
from pymatgen.util.typing import CompositionLike
2526

2627
# This module implements representations of grain boundaries, as well as
@@ -55,10 +56,10 @@ def __init__(
5556
lattice: np.ndarray | Lattice,
5657
species: Sequence[CompositionLike],
5758
coords: Sequence[ArrayLike],
58-
rotation_axis: tuple[float, float, float],
59+
rotation_axis: Vector3D,
5960
rotation_angle: float,
60-
gb_plane: tuple[float, float, float],
61-
join_plane: tuple[float, float, float],
61+
gb_plane: Vector3D,
62+
join_plane: Vector3D,
6263
init_cell: Structure,
6364
vacuum_thickness: float,
6465
ab_shift: tuple[float, float],

pymatgen/core/lattice.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
if TYPE_CHECKING:
2424
from numpy.typing import ArrayLike
2525

26+
from pymatgen.core.trajectory import Vector3D
27+
2628
__author__ = "Shyue Ping Ong, Michael Kocher"
2729
__copyright__ = "Copyright 2011, The Materials Project"
2830
__maintainer__ = "Shyue Ping Ong"
@@ -67,7 +69,7 @@ def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True
6769
self._pbc = tuple(pbc)
6870

6971
@property
70-
def lengths(self) -> tuple[float, float, float]:
72+
def lengths(self) -> Vector3D:
7173
"""
7274
Lattice lengths.
7375
@@ -76,7 +78,7 @@ def lengths(self) -> tuple[float, float, float]:
7678
return tuple(np.sqrt(np.sum(self._matrix**2, axis=1)).tolist()) # type: ignore
7779

7880
@property
79-
def angles(self) -> tuple[float, float, float]:
81+
def angles(self) -> Vector3D:
8082
"""
8183
Lattice angles.
8284
@@ -414,7 +416,7 @@ def c(self) -> float:
414416
return self.lengths[2]
415417

416418
@property
417-
def abc(self) -> tuple[float, float, float]:
419+
def abc(self) -> Vector3D:
418420
"""Lengths of the lattice vectors, i.e. (a, b, c)."""
419421
return self.lengths
420422

pymatgen/core/sites.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
self._species: Composition = species # type: ignore
7272
self.coords: np.ndarray = coords # type: ignore
7373
self.properties: dict = properties or {}
74-
self.label = label if label else self.species_string
74+
self._label = label
7575

7676
def __getattr__(self, attr):
7777
# overriding getattr doesn't play nicely with pickle, so we can't use self._properties
@@ -86,7 +86,7 @@ def species(self) -> Composition:
8686
return self._species
8787

8888
@species.setter
89-
def species(self, species: SpeciesLike | CompositionLike):
89+
def species(self, species: SpeciesLike | CompositionLike) -> None:
9090
if not isinstance(species, Composition):
9191
try:
9292
species = Composition({get_el_sp(species): 1}) # type: ignore
@@ -97,13 +97,22 @@ def species(self, species: SpeciesLike | CompositionLike):
9797
raise ValueError("Species occupancies sum to more than 1!")
9898
self._species = species
9999

100+
@property
101+
def label(self) -> str:
102+
"""Site label."""
103+
return self._label if self._label is not None else self.species_string
104+
105+
@label.setter
106+
def label(self, label: str) -> None:
107+
self._label = label
108+
100109
@property
101110
def x(self) -> float:
102111
"""Cartesian x coordinate."""
103112
return self.coords[0]
104113

105114
@x.setter
106-
def x(self, x: float):
115+
def x(self, x: float) -> None:
107116
self.coords[0] = x
108117

109118
@property
@@ -112,7 +121,7 @@ def y(self) -> float:
112121
return self.coords[1]
113122

114123
@y.setter
115-
def y(self, y: float):
124+
def y(self, y: float) -> None:
116125
self.coords[1] = y
117126

118127
@property
@@ -121,7 +130,7 @@ def z(self) -> float:
121130
return self.coords[2]
122131

123132
@z.setter
124-
def z(self, z: float):
133+
def z(self, z: float) -> None:
125134
self.coords[2] = z
126135

127136
def distance(self, other) -> float:
@@ -345,7 +354,7 @@ def __init__(
345354
self._species: Composition = species # type: ignore
346355
self._coords: np.ndarray | None = None
347356
self.properties: dict = properties or {}
348-
self.label = label if label else self.species_string
357+
self._label = label
349358

350359
def __hash__(self) -> int:
351360
"""
@@ -360,7 +369,7 @@ def lattice(self) -> Lattice:
360369
return self._lattice
361370

362371
@lattice.setter
363-
def lattice(self, lattice: Lattice):
372+
def lattice(self, lattice: Lattice) -> None:
364373
"""Sets Lattice associated with PeriodicSite."""
365374
self._lattice = lattice
366375
self._coords = self._lattice.get_cartesian_coords(self._frac_coords)
@@ -373,7 +382,7 @@ def coords(self) -> np.ndarray:
373382
return self._coords
374383

375384
@coords.setter
376-
def coords(self, coords):
385+
def coords(self, coords) -> None:
377386
"""Set Cartesian coordinates."""
378387
self._coords = np.array(coords)
379388
self._frac_coords = self._lattice.get_fractional_coords(self._coords)
@@ -384,7 +393,7 @@ def frac_coords(self) -> np.ndarray:
384393
return self._frac_coords
385394

386395
@frac_coords.setter
387-
def frac_coords(self, frac_coords):
396+
def frac_coords(self, frac_coords) -> None:
388397
"""Set fractional coordinates."""
389398
self._frac_coords = np.array(frac_coords)
390399
self._coords = self._lattice.get_cartesian_coords(self._frac_coords)
@@ -395,7 +404,7 @@ def a(self) -> float:
395404
return self._frac_coords[0]
396405

397406
@a.setter
398-
def a(self, a: float):
407+
def a(self, a: float) -> None:
399408
self._frac_coords[0] = a
400409
self._coords = self._lattice.get_cartesian_coords(self._frac_coords)
401410

@@ -405,7 +414,7 @@ def b(self) -> float:
405414
return self._frac_coords[1]
406415

407416
@b.setter
408-
def b(self, b: float):
417+
def b(self, b: float) -> None:
409418
self._frac_coords[1] = b
410419
self._coords = self._lattice.get_cartesian_coords(self._frac_coords)
411420

@@ -415,7 +424,7 @@ def c(self) -> float:
415424
return self._frac_coords[2]
416425

417426
@c.setter
418-
def c(self, c: float):
427+
def c(self, c: float) -> None:
419428
self._frac_coords[2] = c
420429
self._coords = self._lattice.get_cartesian_coords(self._frac_coords)
421430

@@ -425,7 +434,7 @@ def x(self) -> float:
425434
return self.coords[0]
426435

427436
@x.setter
428-
def x(self, x: float):
437+
def x(self, x: float) -> None:
429438
self.coords[0] = x
430439
self._frac_coords = self._lattice.get_fractional_coords(self.coords)
431440

@@ -435,7 +444,7 @@ def y(self) -> float:
435444
return self.coords[1]
436445

437446
@y.setter
438-
def y(self, y: float):
447+
def y(self, y: float) -> None:
439448
self.coords[1] = y
440449
self._frac_coords = self._lattice.get_fractional_coords(self.coords)
441450

@@ -445,7 +454,7 @@ def z(self) -> float:
445454
return self.coords[2]
446455

447456
@z.setter
448-
def z(self, z: float):
457+
def z(self, z: float) -> None:
449458
self.coords[2] = z
450459
self._frac_coords = self._lattice.get_fractional_coords(self.coords)
451460

@@ -455,7 +464,7 @@ def to_unit_cell(self, in_place=False) -> PeriodicSite | None:
455464
if in_place:
456465
self.frac_coords = np.array(frac_coords)
457466
return None
458-
return PeriodicSite(self.species, frac_coords, self.lattice, properties=self.properties)
467+
return PeriodicSite(self.species, frac_coords, self.lattice, properties=self.properties, label=self.label)
459468

460469
def is_periodic_image(self, other: PeriodicSite, tolerance: float = 1e-8, check_lattice: bool = True) -> bool:
461470
"""

0 commit comments

Comments
 (0)