diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 3ee0b32e..035f0fe6 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -23,6 +23,7 @@ TYPE_CHECKING, ForwardRef, Generic, + Optional, TypeVar, get_args, ) @@ -40,7 +41,7 @@ from urllib3.util.retry import Retry from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import api_sanitize, validate_ids +from mp_api.client.core.utils import validate_ids try: import boto3 @@ -57,6 +58,8 @@ if TYPE_CHECKING: from typing import Any, Callable + from pydantic.fields import FieldInfo + try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover @@ -150,12 +153,6 @@ def __init__( else: self._s3_client = None - self.document_model = ( - api_sanitize(self.document_model) # type: ignore - if self.document_model is not None - else None # type: ignore - ) - @property def session(self) -> requests.Session: if not self._session: @@ -1057,10 +1054,8 @@ def _convert_to_model(self, data: list[dict]): (list[MPDataDoc]): List of MPDataDoc objects """ - raw_doc_list = [self.document_model.model_validate(d) for d in data] # type: ignore - - if len(raw_doc_list) > 0: - data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0]) + if len(data) > 0: + data_model, set_fields, _ = self._generate_returned_model(data[0]) data = [ data_model( @@ -1070,33 +1065,37 @@ def _convert_to_model(self, data: list[dict]): if field in set_fields } ) - for raw_doc in raw_doc_list + for raw_doc in data ] return data - def _generate_returned_model(self, doc): + def _generate_returned_model( + self, doc: dict[str, Any] + ) -> tuple[BaseModel, list[str], list[str]]: model_fields = self.document_model.model_fields - - set_fields = doc.model_fields_set + set_fields = [k for k in doc if k in model_fields] unset_fields = [field for field in model_fields if field not in set_fields] # Update with locals() from external module if needed - other_vars = {} if any( + isinstance(field_meta.annotation, ForwardRef) + for field_meta in model_fields.values() + ) or any( isinstance(typ, ForwardRef) for field_meta in model_fields.values() for typ in get_args(field_meta.annotation) ): - other_vars = vars(import_module(self.document_model.__module__)) - - include_fields = { - name: ( - model_fields[name].annotation, - model_fields[name], + vars(import_module(self.document_model.__module__)) + + include_fields: dict[str, tuple[type, FieldInfo]] = {} + for name in set_fields: + field_copy = model_fields[name]._copy() + field_copy.default = None + include_fields[name] = ( + Optional[model_fields[name].annotation], + field_copy, ) - for name in set_fields - } data_model = create_model( # type: ignore "MPDataDoc", @@ -1104,10 +1103,18 @@ def _generate_returned_model(self, doc): # TODO fields_not_requested is not the same as unset_fields # i.e. field could be requested but not available in the raw doc fields_not_requested=(list[str], unset_fields), - __base__=self.document_model, + __doc__=".".join( + [ + getattr(self.document_model, k, "") + for k in ("__module__", "__name__") + ] + ), + __module__=self.document_model.__module__, ) - if other_vars: - data_model.model_rebuild(_types_namespace=other_vars) + # if other_vars: + # data_model.model_rebuild(_types_namespace=other_vars) + + orig_rester_name = self.document_model.__name__ def new_repr(self) -> str: extra = ",\n".join( @@ -1116,7 +1123,7 @@ def new_repr(self) -> str: if n == "fields_not_requested" or n in set_fields ) - s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 + s = f"\033[4m\033[1m{self.__class__.__name__}<{orig_rester_name}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 return s def new_str(self) -> str: @@ -1230,8 +1237,14 @@ def get_data_by_id( stacklevel=2, ) - if self.primary_key in ["material_id", "task_id"]: - validate_ids([document_id]) + if self.primary_key in [ + "material_id", + "task_id", + "battery_id", + "spectrum_id", + "thermo_id", + ]: + document_id = validate_ids([document_id])[0] if isinstance(fields, str): # pragma: no cover fields = (fields,) # type: ignore diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 1bba9195..83022e0f 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -1,14 +1,7 @@ from __future__ import annotations -import re -from functools import cache -from typing import Optional, get_args - -from maggma.utils import get_flat_models_from_model +from emmet.core.mpid_ext import validate_identifier from monty.json import MSONable -from pydantic import BaseModel -from pydantic._internal._utils import lenient_issubclass -from pydantic.fields import FieldInfo from mp_api.client.core.settings import MAPIClientSettings @@ -31,74 +24,10 @@ def validate_ids(id_list: list[str]): " data for all IDs and filter locally." ) - pattern = "(mp|mvc|mol|mpcule)-.*" - - for entry in id_list: - if re.match(pattern, entry) is None: - raise ValueError(f"{entry} is not formatted correctly!") - - return id_list - - -@cache -def api_sanitize( - pydantic_model: BaseModel, - fields_to_leave: list[str] | None = None, - allow_dict_msonable=False, -): - """Function to clean up pydantic models for the API by: - 1.) Making fields optional - 2.) Allowing dictionaries in-place of the objects for MSONable quantities. - - WARNING: This works in place, so it mutates the model and all sub-models - - Args: - pydantic_model (BaseModel): Pydantic model to alter - fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field". - Defaults to None. - allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities. - Defaults to False - """ - models = [ - model - for model in get_flat_models_from_model(pydantic_model) - if issubclass(model, BaseModel) - ] # type: list[BaseModel] - - fields_to_leave = fields_to_leave or [] - fields_tuples = [f.split(".") for f in fields_to_leave] - assert all(len(f) == 2 for f in fields_tuples) - - for model in models: - model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]} - for name, field in model.model_fields.items(): - field_type = field.annotation - - if field_type is not None and allow_dict_msonable: - if lenient_issubclass(field_type, MSONable): - field_type = allow_msonable_dict(field_type) - else: - for sub_type in get_args(field_type): - if lenient_issubclass(sub_type, MSONable): - allow_msonable_dict(sub_type) - - if name not in model_fields_to_leave: - new_field = FieldInfo.from_annotated_attribute( - Optional[field_type], None - ) - - for attr in ( - "json_schema_extra", - "exclude", - ): - if (val := getattr(field, attr)) is not None: - setattr(new_field, attr, val) - - model.model_fields[name] = new_field - - model.model_rebuild(force=True) - - return pydantic_model + # TODO: after the transition to AlphaID in the document models, + # The following line should be changed to + # return [validate_identifier(idx,serialize=True) for idx in id_list] + return [str(validate_identifier(idx)) for idx in id_list] def allow_msonable_dict(monty_cls: type[MSONable]): diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 064550b9..347eba64 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -12,7 +12,7 @@ from emmet.core.mpid import MPID from emmet.core.settings import EmmetSettings from emmet.core.tasks import TaskDoc -from emmet.core.thermo import ThermoType +from emmet.core.types.enums import ThermoType from emmet.core.vasp.calc_types import CalcType from monty.json import MontyDecoder from packaging import version @@ -1341,7 +1341,7 @@ def get_charge_density_from_task_id( decoder = MontyDecoder().decode if self.monty_decode else json.loads kwargs = dict( bucket="materialsproject-parsed", - key=f"chgcars/{str(task_id)}.json.gz", + key=f"chgcars/{validate_ids([task_id])[0]}.json.gz", decoder=decoder, ) chgcar = self.materials.tasks._query_open_data(**kwargs)[0] diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index d7284712..21503b2e 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -237,7 +237,7 @@ def get_bandstructure_from_task_id(self, task_id: str): decoder = MontyDecoder().decode if self.monty_decode else json.loads result = self._query_open_data( bucket="materialsproject-parsed", - key=f"bandstructures/{task_id}.json.gz", + key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", decoder=decoder, )[0] @@ -433,7 +433,7 @@ def get_dos_from_task_id(self, task_id: str): decoder = MontyDecoder().decode if self.monty_decode else json.loads result = self._query_open_data( bucket="materialsproject-parsed", - key=f"dos/{task_id}.json.gz", + key=f"dos/{validate_ids([task_id])[0]}.json.gz", decoder=decoder, )[0] diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 4473900d..1f242c87 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -3,7 +3,8 @@ from collections import defaultdict import numpy as np -from emmet.core.thermo import ThermoDoc, ThermoType +from emmet.core.thermo import ThermoDoc +from emmet.core.types.enums import ThermoType from monty.json import MontyDecoder from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index b4ee1d82..cd4d1647 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -1,11 +1,16 @@ from __future__ import annotations -from emmet.core.xas import Edge, Type, XASDoc +from typing import TYPE_CHECKING + +from emmet.core.xas import XASDoc from pymatgen.core.periodic_table import Element from mp_api.client.core import BaseRester from mp_api.client.core.utils import validate_ids +if TYPE_CHECKING: + from emmet.core.types.enums import XasEdge, XasType + class XASRester(BaseRester[XASDoc]): suffix = "materials/xas" @@ -14,13 +19,13 @@ class XASRester(BaseRester[XASDoc]): def search( self, - edge: Edge | None = None, + edge: XasEdge | None = None, absorbing_element: Element | None = None, formula: str | None = None, chemsys: str | list[str] | None = None, elements: list[str] | None = None, material_ids: list[str] | None = None, - spectrum_type: Type | None = None, + spectrum_type: XasType | None = None, spectrum_ids: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, @@ -30,7 +35,7 @@ def search( """Query core XAS docs using a variety of search criteria. Arguments: - edge (Edge): The absorption edge (e.g. K, L2, L3, L2,3). + edge (XasEdge): The absorption edge (e.g. K, L2, L3, L2,3). absorbing_element (Element): The absorbing element. formula (str): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). @@ -39,7 +44,7 @@ def search( elements (List[str]): A list of elements. material_ids (str, List[str]): A single Material ID string or list of strings (e.g., mp-149, [mp-149, mp-13]). - spectrum_type (Type): Spectrum type (e.g. EXAFS, XAFS, or XANES). + spectrum_type (XasType): Spectrum type (e.g. EXAFS, XAFS, or XANES). spectrum_ids (str, List[str]): A single Spectrum ID string or list of strings (e.g., mp-149-XANES-Li-K, [mp-149-XANES-Li-K, mp-13-XANES-Li-K]). num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. diff --git a/tests/materials/test_thermo.py b/tests/materials/test_thermo.py index acb58d8b..9db613f8 100644 --- a/tests/materials/test_thermo.py +++ b/tests/materials/test_thermo.py @@ -2,7 +2,7 @@ from core_function import client_search_testing import pytest -from emmet.core.thermo import ThermoType +from emmet.core.types.enums import ThermoType from pymatgen.analysis.phase_diagram import PhaseDiagram from mp_api.client.routes.materials.thermo import ThermoRester diff --git a/tests/materials/test_xas.py b/tests/materials/test_xas.py index d03af7ab..980e5d9d 100644 --- a/tests/materials/test_xas.py +++ b/tests/materials/test_xas.py @@ -2,7 +2,7 @@ from core_function import client_search_testing import pytest -from emmet.core.xas import Edge, Type +from emmet.core.types.enums import XasEdge, XasType from pymatgen.core.periodic_table import Element from mp_api.client.routes.materials.xas import XASRester @@ -33,8 +33,8 @@ def rester(): } # type: dict custom_field_tests = { - "edge": Edge.L2_3, - "spectrum_type": Type.EXAFS, + "edge": XasEdge.L2_3, + "spectrum_type": XasType.EXAFS, "absorbing_element": Element("Ce"), "required_elements": [Element("Ce")], "formula": "Ce(WO4)2", diff --git a/tests/test_client.py b/tests/test_client.py index 25bf1812..3b2f64da 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -67,6 +67,10 @@ def test_generic_get_methods(rester): use_document_model=True, ) + docs_check = lambda _docs: all( + rester.document_model.__module__ == _doc.__module__ for _doc in _docs + ) + if name not in ignore_generic: key = rester.primary_key if name not in key_only_resters: @@ -74,12 +78,22 @@ def test_generic_get_methods(rester): key = rester.available_fields[0] doc = rester._query_resource_data({"_limit": 1}, fields=[key])[0] - assert isinstance(doc, rester.document_model) + assert docs_check([doc]) if name not in search_only_resters: - doc = rester.get_data_by_id(doc.model_dump()[key], fields=[key]) - assert isinstance(doc, rester.document_model) + docs = rester.search( + **{key + "s": [doc.model_dump()[key]]}, fields=[key] + ) + assert docs_check(docs) elif name not in special_resters: - doc = rester.get_data_by_id(key_only_resters[name], fields=[key]) - assert isinstance(doc, rester.document_model) + search_method = "search" + if name == "materials_robocrys": + search_method += "_docs" + docs = getattr(rester, search_method)( + **{key + "s": [key_only_resters[name]]}, fields=[key] + ) + with pytest.warns(DeprecationWarning, match="get_data_by_id is deprecated"): + _ = rester.get_data_by_id(key_only_resters[name], fields=[key]) + + assert docs_check(docs) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 8eb55ef0..0ac7040c 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -6,7 +6,7 @@ import numpy as np import pytest from emmet.core.tasks import TaskDoc -from emmet.core.thermo import ThermoType +from emmet.core.types.enums import ThermoType from emmet.core.vasp.calc_types import CalcType from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry, PourbaixDiagram, PourbaixEntry @@ -303,7 +303,7 @@ def test_get_charge_density_from_material_id(self, mpr): "mp-149", inc_task_doc=True ) assert isinstance(chgcar, Chgcar) - assert isinstance(task_doc, TaskDoc) + assert isinstance(TaskDoc.model_validate(task_doc.model_dump()), TaskDoc) def test_get_charge_density_from_task_id(self, mpr): chgcar = mpr.get_charge_density_from_task_id("mp-2246557") @@ -313,7 +313,7 @@ def test_get_charge_density_from_task_id(self, mpr): "mp-2246557", inc_task_doc=True ) assert isinstance(chgcar, Chgcar) - assert isinstance(task_doc, TaskDoc) + assert isinstance(TaskDoc.model_validate(task_doc.model_dump()), TaskDoc) def test_get_wulff_shape(self, mpr): ws = mpr.get_wulff_shape("mp-126")