Skip to content

Commit 30ca255

Browse files
author
Shyue Ping Ong
committed
Fixes assymetric == between Element and Species. Fixes #3504.
1 parent 7575993 commit 30ca255

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

pymatgen/core/periodic_table.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def ground_state_term_symbol(self):
454454
return J_sorted_terms[-1][0]
455455

456456
def __eq__(self, other: object) -> bool:
457-
return isinstance(other, Element) and self.Z == other.Z
457+
return isinstance(self, Element) and isinstance(other, Element) and self.Z == other.Z
458458

459459
def __hash__(self) -> int:
460460
return self.Z
@@ -895,7 +895,12 @@ class Species(MSONable, Stringify):
895895

896896
STRING_MODE = "SUPERSCRIPT"
897897

898-
def __init__(self, symbol: SpeciesLike, oxidation_state: float | None = None, spin: float | None = None) -> None:
898+
def __init__(
899+
self,
900+
symbol: SpeciesLike,
901+
oxidation_state: float | None = None,
902+
spin: float | None = None,
903+
) -> None:
899904
"""
900905
Args:
901906
symbol (str): Element symbol optionally incl. oxidation state. E.g. Fe, Fe2+, O2-.
@@ -1138,7 +1143,9 @@ def get_shannon_radius(
11381143
return data[f"{radius_type}_radius"]
11391144

11401145
def get_crystal_field_spin(
1141-
self, coordination: Literal["oct", "tet"] = "oct", spin_config: Literal["low", "high"] = "high"
1146+
self,
1147+
coordination: Literal["oct", "tet"] = "oct",
1148+
spin_config: Literal["low", "high"] = "high",
11421149
) -> float:
11431150
"""Calculate the crystal field spin based on coordination and spin
11441151
configuration. Only works for transition metal species.

tests/core/test_periodic_table.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ def test_eq(self):
383383
assert self.specie4 != self.specie3
384384
assert self.specie1 != Element("Fe")
385385
assert Element("Fe") != self.specie1
386+
assert Element("Fe") == Element("Fe")
387+
assert Species("Fe", 0) != Element("Fe")
388+
assert Element("Fe") != Species("Fe", 0)
386389

387390
def test_cmp(self):
388391
assert self.specie1 < self.specie2, "Fe2+ should be < Fe3+"
@@ -423,7 +426,10 @@ def test_get_crystal_field_spin(self):
423426

424427
for elem in ("Li+", "Ge4+", "H+"):
425428
symbol = Species(elem).symbol
426-
with pytest.raises(AttributeError, match=f"Invalid element {symbol} for crystal field calculation"):
429+
with pytest.raises(
430+
AttributeError,
431+
match=f"Invalid element {symbol} for crystal field calculation",
432+
):
427433
Species(elem).get_crystal_field_spin()
428434
with pytest.raises(AttributeError, match="Invalid oxidation state 10 for element Fe"):
429435
Species("Fe", 10).get_crystal_field_spin()

0 commit comments

Comments
 (0)