diff --git a/src/pymatgen/electronic_structure/cohp.py b/src/pymatgen/electronic_structure/cohp.py index 9e1d623c230..863bb568ad8 100644 --- a/src/pymatgen/electronic_structure/cohp.py +++ b/src/pymatgen/electronic_structure/cohp.py @@ -1405,7 +1405,11 @@ def as_dict(self) -> dict[str, Any]: { key: { "icohp": {str(spin): value for spin, value in val["icohp"].items()}, - "orbitals": [[n, int(orb)] for n, orb in val["orbitals"]], + "orbitals": ( + [[n, int(orb)] for n, orb in val["orbitals"]] + if isinstance(val["orbitals"][0], (list, tuple)) + else list(val["orbitals"]) # Handle LCFO orbitals + ), } for key, val in entry.items() } @@ -1433,17 +1437,24 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: for key in lab_orb_icohp[orb]: sub_dict = {} if key == "icohp": - sub_dict[key] = { - Spin.up: lab_orb_icohp[orb][key]["1"], - Spin.down: lab_orb_icohp[orb][key]["-1"], - } + if dct.get("is_spin_polarized"): + sub_dict[key] = { + Spin.up: lab_orb_icohp[orb][key]["1"], + Spin.down: lab_orb_icohp[orb][key]["-1"], + } + else: + sub_dict[key] = {Spin.up: lab_orb_icohp[orb][key]["1"]} if key == "orbitals": orb_temp = [] for item in lab_orb_icohp[orb][key]: - item[1] = Orbital(item[1]) - orb_temp.append(item) + # Handle LCFO orbitals + if isinstance(item, (list, tuple)): + item[1] = Orbital(item[1]) + orb_temp.append(tuple(item)) + else: + orb_temp.append(item) sub_dict[key] = orb_temp # type: ignore[assignment] new_list_orb[bond_num][orb].update(sub_dict) diff --git a/tests/io/lobster/test_outputs.py b/tests/io/lobster/test_outputs.py index 6362b40d4f4..0684127f680 100644 --- a/tests/io/lobster/test_outputs.py +++ b/tests/io/lobster/test_outputs.py @@ -2169,14 +2169,15 @@ def test_values(self): assert self.icohp_lcfo_non_orbitalwise.icohplist["16"]["icohp"][Spin.down] == approx(-0.29842) def test_msonable(self): - dict_data = self.icobi_orbitalwise_spinpolarized.as_dict() - icohplist_from_dict = Icohplist.from_dict(dict_data) - all_attributes = vars(self.icobi_orbitalwise_spinpolarized) - for attr_name, attr_value in all_attributes.items(): - if isinstance(attr_value, IcohpCollection): - assert getattr(icohplist_from_dict, attr_name).as_dict() == attr_value.as_dict() - else: - assert getattr(icohplist_from_dict, attr_name) == attr_value + for icohplist_obj in [self.icobi_orbitalwise_spinpolarized, self.icohp_nacl_511_nsp]: + dict_data = icohplist_obj.as_dict() + icohplist_from_dict = Icohplist.from_dict(dict_data) + all_attributes = vars(icohplist_obj) + for attr_name, attr_value in all_attributes.items(): + if isinstance(attr_value, IcohpCollection): + assert getattr(icohplist_from_dict, attr_name).as_dict() == attr_value.as_dict() + else: + assert getattr(icohplist_from_dict, attr_name) == attr_value def test_missing_trailing_newline(self): fname = f"{self.tmp_path}/icohplist"