73 changes: 43 additions & 30 deletions mp_api/client/core/client.py
@@ -23,6 +23,7 @@
TYPE_CHECKING,
ForwardRef,
Generic,
Optional,
TypeVar,
get_args,
)
@@ -40,7 +41,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import api_sanitize, validate_ids
from mp_api.client.core.utils import validate_ids

try:
import boto3
@@ -57,6 +58,8 @@
if TYPE_CHECKING:
from typing import Any, Callable

from pydantic.fields import FieldInfo

try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
@@ -150,12 +153,6 @@ def __init__(
else:
self._s3_client = None

self.document_model = (
api_sanitize(self.document_model) # type: ignore
if self.document_model is not None
else None # type: ignore
)

@property
def session(self) -> requests.Session:
if not self._session:
@@ -1057,10 +1054,8 @@ def _convert_to_model(self, data: list[dict]):
(list[MPDataDoc]): List of MPDataDoc objects

"""
raw_doc_list = [self.document_model.model_validate(d) for d in data] # type: ignore

if len(raw_doc_list) > 0:
data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0])
if len(data) > 0:
data_model, set_fields, _ = self._generate_returned_model(data[0])

data = [
data_model(
@@ -1070,44 +1065,56 @@ def _convert_to_model(self, data: list[dict]):
if field in set_fields
}
)
for raw_doc in raw_doc_list
for raw_doc in data
]

return data

def _generate_returned_model(self, doc):
def _generate_returned_model(
self, doc: dict[str, Any]
) -> tuple[BaseModel, list[str], list[str]]:
model_fields = self.document_model.model_fields

set_fields = doc.model_fields_set
set_fields = [k for k in doc if k in model_fields]
unset_fields = [field for field in model_fields if field not in set_fields]

# Update with locals() from external module if needed
other_vars = {}
if any(
isinstance(field_meta.annotation, ForwardRef)
for field_meta in model_fields.values()
) or any(
isinstance(typ, ForwardRef)
for field_meta in model_fields.values()
for typ in get_args(field_meta.annotation)
):
other_vars = vars(import_module(self.document_model.__module__))

include_fields = {
name: (
model_fields[name].annotation,
model_fields[name],
vars(import_module(self.document_model.__module__))

include_fields: dict[str, tuple[type, FieldInfo]] = {}
for name in set_fields:
field_copy = model_fields[name]._copy()
field_copy.default = None
include_fields[name] = (
Optional[model_fields[name].annotation],
field_copy,
)
for name in set_fields
}

data_model = create_model( # type: ignore
"MPDataDoc",
**include_fields,
# TODO fields_not_requested is not the same as unset_fields
# i.e. field could be requested but not available in the raw doc
fields_not_requested=(list[str], unset_fields),
__base__=self.document_model,
__doc__=".".join(
[
getattr(self.document_model, k, "")
for k in ("__module__", "__name__")
]
),
__module__=self.document_model.__module__,
)
if other_vars:
data_model.model_rebuild(_types_namespace=other_vars)
# if other_vars:
# data_model.model_rebuild(_types_namespace=other_vars)

orig_rester_name = self.document_model.__name__

def new_repr(self) -> str:
extra = ",\n".join(
@@ -1116,7 +1123,7 @@ def new_repr(self) -> str:
if n == "fields_not_requested" or n in set_fields
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
s = f"\033[4m\033[1m{self.__class__.__name__}<{orig_rester_name}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
return s

def new_str(self) -> str:
@@ -1230,8 +1237,14 @@ def get_data_by_id(
stacklevel=2,
)

if self.primary_key in ["material_id", "task_id"]:
validate_ids([document_id])
if self.primary_key in [
"material_id",
"task_id",
"battery_id",
"spectrum_id",
"thermo_id",
]:
document_id = validate_ids([document_id])[0]

if isinstance(fields, str): # pragma: no cover
fields = (fields,) # type: ignore
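The new `_generate_returned_model` drops the cached, mutating `api_sanitize` pass (removed from `utils.py` below) and instead builds a per-query model with `pydantic.create_model`, keeping only the fields present in the raw document and marking each as `Optional` with a `None` default. A minimal, self-contained sketch of that technique follows; `ExampleDoc`, its fields, and `model_for_returned_fields` are illustrative stand-ins, not the client's actual document models or helpers.

```python
from typing import Any, Optional

from pydantic import BaseModel, create_model


class ExampleDoc(BaseModel):
    """Stand-in document model; the real ones come from emmet."""

    material_id: Optional[str] = None
    formula_pretty: Optional[str] = None
    band_gap: Optional[float] = None


def model_for_returned_fields(doc: dict[str, Any]) -> type[BaseModel]:
    """Build a trimmed 'MPDataDoc'-style model from a raw response dict:
    only the fields present in `doc` are declared, each Optional with a
    None default, so partial API responses validate cleanly."""
    model_fields = ExampleDoc.model_fields
    set_fields = [k for k in doc if k in model_fields]
    unset_fields = [f for f in model_fields if f not in set_fields]

    include_fields = {
        name: (Optional[model_fields[name].annotation], None) for name in set_fields
    }
    return create_model(
        "ExampleDataDoc",
        **include_fields,
        # mirrors the fields_not_requested bookkeeping in the diff above
        fields_not_requested=(list[str], unset_fields),
        __base__=ExampleDoc,
    )


raw = {"material_id": "mp-149", "band_gap": 1.1}
TrimmedDoc = model_for_returned_fields(raw)
print(TrimmedDoc(**raw).fields_not_requested)  # ['formula_pretty']
```

The real implementation additionally copies each `FieldInfo` (so descriptions and other metadata survive) and inherits from the rester's `document_model` rather than a toy base class.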
81 changes: 5 additions & 76 deletions mp_api/client/core/utils.py
@@ -1,14 +1,7 @@
from __future__ import annotations

import re
from functools import cache
from typing import Optional, get_args

from maggma.utils import get_flat_models_from_model
from emmet.core.mpid_ext import validate_identifier
from monty.json import MSONable
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo

from mp_api.client.core.settings import MAPIClientSettings

@@ -31,74 +24,10 @@ def validate_ids(id_list: list[str]):
" data for all IDs and filter locally."
)

pattern = "(mp|mvc|mol|mpcule)-.*"

for entry in id_list:
if re.match(pattern, entry) is None:
raise ValueError(f"{entry} is not formatted correctly!")

return id_list


@cache
def api_sanitize(
pydantic_model: BaseModel,
fields_to_leave: list[str] | None = None,
allow_dict_msonable=False,
):
"""Function to clean up pydantic models for the API by:
1.) Making fields optional
2.) Allowing dictionaries in-place of the objects for MSONable quantities.

WARNING: This works in place, so it mutates the model and all sub-models

Args:
pydantic_model (BaseModel): Pydantic model to alter
fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
Defaults to None.
allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
Defaults to False
"""
models = [
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: list[BaseModel]

fields_to_leave = fields_to_leave or []
fields_tuples = [f.split(".") for f in fields_to_leave]
assert all(len(f) == 2 for f in fields_tuples)

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name, field in model.model_fields.items():
field_type = field.annotation

if field_type is not None and allow_dict_msonable:
if lenient_issubclass(field_type, MSONable):
field_type = allow_msonable_dict(field_type)
else:
for sub_type in get_args(field_type):
if lenient_issubclass(sub_type, MSONable):
allow_msonable_dict(sub_type)

if name not in model_fields_to_leave:
new_field = FieldInfo.from_annotated_attribute(
Optional[field_type], None
)

for attr in (
"json_schema_extra",
"exclude",
):
if (val := getattr(field, attr)) is not None:
setattr(new_field, attr, val)

model.model_fields[name] = new_field

model.model_rebuild(force=True)

return pydantic_model
# TODO: after the transition to AlphaID in the document models,
# The following line should be changed to
# return [validate_identifier(idx,serialize=True) for idx in id_list]
return [str(validate_identifier(idx)) for idx in id_list]


def allow_msonable_dict(monty_cls: type[MSONable]):
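`validate_ids` now delegates to emmet's `validate_identifier` instead of the old `(mp|mvc|mol|mpcule)-.*` regex, and it returns the serialized identifiers rather than only checking them, which is why callers elsewhere in this PR start consuming its return value. A rough usage sketch, assuming a recent emmet-core is installed; the exact error raised for a malformed ID comes from emmet, not this client.

```python
from mp_api.client.core.utils import validate_ids

# Well-formed Materials Project identifiers should round-trip as strings.
ids = validate_ids(["mp-149", "mp-13"])
print(ids)  # expected: ['mp-149', 'mp-13']

# The validated/serialized form is what downstream code now consumes,
# e.g. when a single document ID is looked up in get_data_by_id:
document_id = validate_ids(["mp-149"])[0]
```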
4 changes: 2 additions & 2 deletions mp_api/client/mprester.py
@@ -12,7 +12,7 @@
from emmet.core.mpid import MPID
from emmet.core.settings import EmmetSettings
from emmet.core.tasks import TaskDoc
from emmet.core.thermo import ThermoType
from emmet.core.types.enums import ThermoType
from emmet.core.vasp.calc_types import CalcType
from monty.json import MontyDecoder
from packaging import version
@@ -1341,7 +1341,7 @@ def get_charge_density_from_task_id(
decoder = MontyDecoder().decode if self.monty_decode else json.loads
kwargs = dict(
bucket="materialsproject-parsed",
key=f"chgcars/{str(task_id)}.json.gz",
key=f"chgcars/{validate_ids([task_id])[0]}.json.gz",
decoder=decoder,
)
chgcar = self.materials.tasks._query_open_data(**kwargs)[0]
4 changes: 2 additions & 2 deletions mp_api/client/routes/materials/electronic_structure.py
@@ -237,7 +237,7 @@ def get_bandstructure_from_task_id(self, task_id: str):
decoder = MontyDecoder().decode if self.monty_decode else json.loads
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"bandstructures/{task_id}.json.gz",
key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz",
decoder=decoder,
)[0]

@@ -433,7 +433,7 @@ def get_dos_from_task_id(self, task_id: str):
decoder = MontyDecoder().decode if self.monty_decode else json.loads
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"dos/{task_id}.json.gz",
key=f"dos/{validate_ids([task_id])[0]}.json.gz",
decoder=decoder,
)[0]

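Both `get_charge_density_from_task_id` in `mprester.py` and the band-structure/DOS helpers above now run the task ID through `validate_ids` before building the S3 object key, so a malformed identifier fails early instead of producing a silent miss on the bucket. The sketch below shows the same pattern against the public `materialsproject-parsed` bucket; the `fetch_parsed_object` helper, the unsigned-access configuration, and the sample task ID are assumptions for illustration, not part of the client.

```python
import gzip
import json

import boto3
from botocore import UNSIGNED
from botocore.config import Config

from mp_api.client.core.utils import validate_ids


def fetch_parsed_object(task_id: str, prefix: str) -> dict:
    """Validate the task ID, then fetch and decode the parsed JSON object."""
    key = f"{prefix}/{validate_ids([task_id])[0]}.json.gz"
    s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    body = s3.get_object(Bucket="materialsproject-parsed", Key=key)["Body"].read()
    return json.loads(gzip.decompress(body))


# dos = fetch_parsed_object("mp-1234", "dos")         # hypothetical task ID
# chgcar = fetch_parsed_object("mp-1234", "chgcars")  # same key layout as mprester.py
```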
3 changes: 2 additions & 1 deletion mp_api/client/routes/materials/thermo.py
@@ -3,7 +3,8 @@
from collections import defaultdict

import numpy as np
from emmet.core.thermo import ThermoDoc, ThermoType
from emmet.core.thermo import ThermoDoc
from emmet.core.types.enums import ThermoType
from monty.json import MontyDecoder
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.core import Element
15 changes: 10 additions & 5 deletions mp_api/client/routes/materials/xas.py
@@ -1,11 +1,16 @@
from __future__ import annotations

from emmet.core.xas import Edge, Type, XASDoc
from typing import TYPE_CHECKING

from emmet.core.xas import XASDoc
from pymatgen.core.periodic_table import Element

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids

if TYPE_CHECKING:
from emmet.core.types.enums import XasEdge, XasType


class XASRester(BaseRester[XASDoc]):
suffix = "materials/xas"
@@ -14,13 +19,13 @@ class XASRester(BaseRester[XASDoc]):

def search(
self,
edge: Edge | None = None,
edge: XasEdge | None = None,
absorbing_element: Element | None = None,
formula: str | None = None,
chemsys: str | list[str] | None = None,
elements: list[str] | None = None,
material_ids: list[str] | None = None,
spectrum_type: Type | None = None,
spectrum_type: XasType | None = None,
spectrum_ids: str | list[str] | None = None,
num_chunks: int | None = None,
chunk_size: int = 1000,
@@ -30,7 +35,7 @@ def search(
"""Query core XAS docs using a variety of search criteria.

Arguments:
edge (Edge): The absorption edge (e.g. K, L2, L3, L2,3).
edge (XasEdge): The absorption edge (e.g. K, L2, L3, L2,3).
absorbing_element (Element): The absorbing element.
formula (str): A formula including anonymized formula
or wild cards (e.g., Fe2O3, ABO3, Si*).
@@ -39,7 +44,7 @@ def search(
elements (List[str]): A list of elements.
material_ids (str, List[str]): A single Material ID string or list of strings
(e.g., mp-149, [mp-149, mp-13]).
spectrum_type (Type): Spectrum type (e.g. EXAFS, XAFS, or XANES).
spectrum_type (XasType): Spectrum type (e.g. EXAFS, XAFS, or XANES).
spectrum_ids (str, List[str]): A single Spectrum ID string or list of strings
(e.g., mp-149-XANES-Li-K, [mp-149-XANES-Li-K, mp-13-XANES-Li-K]).
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
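With `Edge` and `Type` replaced by `XasEdge` and `XasType` from `emmet.core.types.enums`, a query against this rester looks roughly as follows; the enum members and query values are the ones exercised by the updated test below, and the call assumes a valid API key is configured (e.g. via `MP_API_KEY`).

```python
from emmet.core.types.enums import XasEdge, XasType
from pymatgen.core.periodic_table import Element

from mp_api.client import MPRester

with MPRester() as mpr:
    docs = mpr.materials.xas.search(
        edge=XasEdge.L2_3,
        spectrum_type=XasType.EXAFS,
        absorbing_element=Element("Ce"),
        formula="Ce(WO4)2",
        fields=["spectrum_id", "material_id", "absorbing_element"],
    )
    print(len(docs), "matching XAS docs")
```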
2 changes: 1 addition & 1 deletion tests/materials/test_thermo.py
@@ -2,7 +2,7 @@
from core_function import client_search_testing

import pytest
from emmet.core.thermo import ThermoType
from emmet.core.types.enums import ThermoType
from pymatgen.analysis.phase_diagram import PhaseDiagram

from mp_api.client.routes.materials.thermo import ThermoRester
6 changes: 3 additions & 3 deletions tests/materials/test_xas.py
@@ -2,7 +2,7 @@
from core_function import client_search_testing

import pytest
from emmet.core.xas import Edge, Type
from emmet.core.types.enums import XasEdge, XasType
from pymatgen.core.periodic_table import Element

from mp_api.client.routes.materials.xas import XASRester
@@ -33,8 +33,8 @@ def rester():
} # type: dict

custom_field_tests = {
"edge": Edge.L2_3,
"spectrum_type": Type.EXAFS,
"edge": XasEdge.L2_3,
"spectrum_type": XasType.EXAFS,
"absorbing_element": Element("Ce"),
"required_elements": [Element("Ce")],
"formula": "Ce(WO4)2",