Skip to content

Commit d1f199d

Browse files
Bump test coverage, revise absorption rester, remove deprecated methods (#1042)
2 parents a8210aa + 6ff1e3c commit d1f199d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+770
-345
lines changed

mp_api/client/core/utils.py

Lines changed: 11 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from __future__ import annotations
22

3-
import re
43
from typing import TYPE_CHECKING, Literal
54

65
import orjson
76
from emmet.core import __version__ as _EMMET_CORE_VER
7+
from emmet.core.mpid_ext import validate_identifier
88
from monty.json import MontyDecoder
99
from packaging.version import parse as parse_version
1010

1111
from mp_api.client.core.settings import MAPIClientSettings
1212

1313
if TYPE_CHECKING:
14-
from monty.json import MSONable
14+
from typing import Any
1515

1616

1717
def _compare_emmet_ver(
@@ -23,6 +23,10 @@ def _compare_emmet_ver(
2323
_compare_emmet_ver("0.84.0rc0","<") returns
2424
emmet.core.__version__ < "0.84.0rc0"
2525
26+
This function may not be used anywhere in the client, but it should
27+
be preserved for future use, in case some degree of backwards
28+
compatibility or feature buy-in is needed.
29+
2630
Parameters
2731
-----------
2832
ref_version : str
@@ -36,41 +40,17 @@ def _compare_emmet_ver(
3640
)(parse_version(ref_version))
3741

3842

39-
if _compare_emmet_ver("0.85.0", ">="):
40-
from emmet.core.mpid_ext import validate_identifier
41-
else:
42-
validate_identifier = None
43-
44-
45-
def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"):
43+
def load_json(
44+
json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"
45+
) -> Any:
4646
"""Utility to load json in consistent manner."""
4747
data = orjson.loads(
4848
json_like if isinstance(json_like, bytes) else json_like.encode(encoding)
4949
)
5050
return MontyDecoder().process_decoded(data) if deser else data
5151

5252

53-
def _legacy_id_validation(id_list: list[str]) -> list[str]:
54-
"""Legacy utility to validate IDs, pre-AlphaID transition.
55-
56-
This function is temporarily maintained to allow for
57-
backwards compatibility with older versions of emmet, and will
58-
not be preserved.
59-
"""
60-
pattern = "(mp|mvc|mol|mpcule)-.*"
61-
if malformed_ids := {
62-
entry for entry in id_list if re.match(pattern, entry) is None
63-
}:
64-
raise ValueError(
65-
f"{'Entry' if len(malformed_ids) == 1 else 'Entries'}"
66-
f" {', '.join(malformed_ids)}"
67-
f"{'is' if len(malformed_ids) == 1 else 'are'} not formatted correctly!"
68-
)
69-
70-
return id_list
71-
72-
73-
def validate_ids(id_list: list[str]):
53+
def validate_ids(id_list: list[str]) -> list[str]:
7454
"""Function to validate material and task IDs.
7555
7656
Args:
@@ -91,36 +71,4 @@ def validate_ids(id_list: list[str]):
9171
# TODO: after the transition to AlphaID in the document models,
9272
# The following line should be changed to
9373
# return [validate_identifier(idx,serialize=True) for idx in id_list]
94-
if validate_identifier:
95-
return [str(validate_identifier(idx)) for idx in id_list]
96-
return _legacy_id_validation(id_list)
97-
98-
99-
def allow_msonable_dict(monty_cls: type[MSONable]):
100-
"""Patch Monty to allow for dict values for MSONable."""
101-
102-
def validate_monty(cls, v, _):
103-
"""Stub validator for MSONable as a dictionary only."""
104-
if isinstance(v, cls):
105-
return v
106-
elif isinstance(v, dict):
107-
# Just validate the simple Monty Dict Model
108-
errors = []
109-
if v.get("@module", "") != monty_cls.__module__:
110-
errors.append("@module")
111-
112-
if v.get("@class", "") != monty_cls.__name__:
113-
errors.append("@class")
114-
115-
if len(errors) > 0:
116-
raise ValueError(
117-
"Missing Monty seriailzation fields in dictionary: {errors}"
118-
)
119-
120-
return v
121-
else:
122-
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")
123-
124-
monty_cls.validate_monty_v2 = classmethod(validate_monty)
125-
126-
return monty_cls
74+
return [str(validate_identifier(idx)) for idx in id_list]

mp_api/client/mprester.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from emmet.core.mpid import MPID, AlphaID
1212
from emmet.core.settings import EmmetSettings
1313
from emmet.core.tasks import TaskDoc
14+
from emmet.core.types.enums import ThermoType
1415
from emmet.core.vasp.calc_types import CalcType
1516
from packaging import version
1617
from pymatgen.analysis.phase_diagram import PhaseDiagram
@@ -25,7 +26,7 @@
2526
from mp_api.client.core import BaseRester, MPRestError
2627
from mp_api.client.core._oxygen_evolution import OxygenEvolution
2728
from mp_api.client.core.settings import MAPIClientSettings
28-
from mp_api.client.core.utils import _compare_emmet_ver, load_json, validate_ids
29+
from mp_api.client.core.utils import load_json, validate_ids
2930
from mp_api.client.routes import GeneralStoreRester, MessagesRester, UserSettingsRester
3031
from mp_api.client.routes.materials import (
3132
AbsorptionRester,
@@ -60,11 +61,6 @@
6061
from mp_api.client.routes.materials.materials import MaterialsRester
6162
from mp_api.client.routes.molecules import MoleculeRester
6263

63-
if _compare_emmet_ver("0.85.0", ">="):
64-
from emmet.core.types.enums import ThermoType
65-
else:
66-
from emmet.core.thermo import ThermoType
67-
6864
if TYPE_CHECKING:
6965
from typing import Any, Literal
7066

@@ -225,6 +221,9 @@ def __init__(
225221
"chemenv",
226222
]
227223

224+
if not self.endpoint.endswith("/"):
225+
self.endpoint += "/"
226+
228227
# Check if emmet version of server is compatible
229228
emmet_version = MPRester.get_emmet_version(self.endpoint)
230229

@@ -239,9 +238,6 @@ def __init__(
239238
if notify_db_version:
240239
raise NotImplementedError("This has not yet been implemented.")
241240

242-
if not self.endpoint.endswith("/"):
243-
self.endpoint += "/"
244-
245241
# Dynamically set rester attributes.
246242
# First, materials and molecules top level resters are set.
247243
# Nested rested are then setup to be loaded dynamically with custom __getattr__ functions.

mp_api/client/routes/materials/absorption.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ class AbsorptionRester(BaseRester):
1616
def search(
1717
self,
1818
material_ids: str | list[str] | None = None,
19-
chemsys: str | list[str] | None = None,
20-
elements: list[str] | None = None,
21-
exclude_elements: list[str] | None = None,
22-
formula: list[str] | None = None,
19+
num_sites: int | tuple[int, int] | None = None,
20+
num_elements: int | tuple[int, int] | None = None,
21+
volume: float | tuple[float, float] | None = None,
22+
density: float | tuple[float, float] | None = None,
23+
band_gap: float | tuple[float, float] | None = None,
2324
num_chunks: int | None = None,
2425
chunk_size: int = 1000,
2526
all_fields: bool = True,
@@ -28,14 +29,26 @@ def search(
2829
"""Query for optical absorption spectra data.
2930
3031
Arguments:
31-
material_ids (str, List[str]): Search for optical absorption data associated with the specified Material IDs
32-
chemsys (str, List[str]): A chemical system or list of chemical systems
33-
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
34-
elements (List[str]): A list of elements.
35-
exclude_elements (List[str]): A list of elements to exclude.
36-
formula (str, List[str]): A formula including anonymized formula
37-
or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
38-
(e.g., [Fe2O3, ABO3]).
32+
material_ids (str, List[str]):
33+
Search for optical absorption data associated with the
34+
specified Material ID(s)
35+
num_sites (int, tuple[int, int]):
36+
Search with a single number or a range of number of sites
37+
in the structure.
38+
num_elements (int, tuple[int, int]):
39+
Search with a single number or a range of number of distinct
40+
elements in the structure.
41+
volume (float, tuple[float, float]):
42+
Search with a single number or a range of structural
43+
(lattice) volumes in ų.
44+
If a single number, an uncertainty of ±0.01 is automatically used.
45+
density (float, tuple[float, float]):
46+
Search with a single number or a range of structural
47+
(lattice) densities, in g/cm³.
48+
If a single number, an uncertainty of ±0.01 is automatically used.
49+
band_gap (float, tuple[float, float]):
50+
Search with a single number or a range of band gaps in eV.
51+
If a single number, an uncertainty of ±0.01 is automatically used.
3952
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
4053
chunk_size (int): Number of data entries per chunk.
4154
all_fields (bool): Whether to return all fields in the document. Defaults to True.
@@ -46,23 +59,27 @@ def search(
4659
"""
4760
query_params = defaultdict(dict) # type: dict
4861

49-
if formula:
50-
if isinstance(formula, str):
51-
formula = [formula]
52-
53-
query_params.update({"formula": ",".join(formula)})
54-
55-
if chemsys:
56-
if isinstance(chemsys, str):
57-
chemsys = [chemsys]
58-
59-
query_params.update({"chemsys": ",".join(chemsys)})
60-
61-
if elements:
62-
query_params.update({"elements": ",".join(elements)})
63-
64-
if exclude_elements:
65-
query_params.update({"exclude_elements": ",".join(exclude_elements)})
62+
aliased = {
63+
"num_sites": "nsites",
64+
"num_elements": "nelements",
65+
"band_gap": "bandgap",
66+
}
67+
user_query = locals()
68+
for k in ("num_sites", "num_elements", "volume", "density", "band_gap"):
69+
if (value := user_query.get(k)) is not None:
70+
if k in ("num_sites", "num_elements") and isinstance(value, int):
71+
value = (value, value)
72+
elif k in ("volume", "density", "band_gap") and isinstance(
73+
value, int | float
74+
):
75+
value = (value - 1e-2, value + 1e-2)
76+
77+
query_params.update(
78+
{
79+
f"{aliased.get(k,k)}_min": value[0],
80+
f"{aliased.get(k,k)}_max": value[1],
81+
}
82+
)
6683

6784
if material_ids:
6885
if isinstance(material_ids, str):
@@ -77,7 +94,6 @@ def search(
7794
}
7895

7996
return super()._search(
80-
formulae=formula,
8197
num_chunks=num_chunks,
8298
chunk_size=chunk_size,
8399
all_fields=all_fields,

0 commit comments

Comments
 (0)