Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
maggma Copyright (c) 2017, The Regents of the University of
Copyright (c) 2017, The Regents of the University of
California, through Lawrence Berkeley National Laboratory (subject
to receipt of any required approvals from the U.S. Dept. of Energy).
All rights reserved.
Expand Down
44 changes: 15 additions & 29 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import inspect
import itertools
import json
import os
import platform
import sys
Expand All @@ -29,9 +28,7 @@
from urllib.parse import quote, urljoin

import requests
from bson import json_util
from emmet.core.utils import jsanitize
from monty.json import MontyDecoder
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
Expand All @@ -40,7 +37,7 @@
from urllib3.util.retry import Retry

from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import api_sanitize, validate_ids
from mp_api.client.core.utils import api_sanitize, load_json, validate_ids

try:
import boto3
Expand Down Expand Up @@ -270,11 +267,7 @@ def _post_resource(
response = self.session.post(url, json=payload, verify=True, params=params)

if response.status_code == 200:
if self.monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
data = json.loads(response.text)

data = load_json(response.text, deser=self.monty_decode)
if self.document_model and use_document_model:
if isinstance(data["data"], dict):
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
Expand All @@ -287,7 +280,7 @@ def _post_resource(

else:
try:
data = json.loads(response.text)["detail"]
data = load_json(response.text)["detail"]
except (JSONDecodeError, KeyError):
data = f"Response {response.text}"
if isinstance(data, str):
Expand Down Expand Up @@ -342,11 +335,7 @@ def _patch_resource(
response = self.session.patch(url, json=payload, verify=True, params=params)

if response.status_code == 200:
if self.monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
data = json.loads(response.text)

data = load_json(response.text, deser=self.monty_decode)
if self.document_model and use_document_model:
if isinstance(data["data"], dict):
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
Expand All @@ -359,7 +348,7 @@ def _patch_resource(

else:
try:
data = json.loads(response.text)["detail"]
data = load_json(response.text)["detail"]
except (JSONDecodeError, KeyError):
data = f"Response {response.text}"
if isinstance(data, str):
Expand All @@ -384,18 +373,24 @@ def _query_open_data(
self,
bucket: str,
key: str,
decoder: Callable,
decoder: Callable | None = None,
) -> tuple[list[dict] | list[bytes], int]:
"""Query and deserialize Materials Project AWS open data s3 buckets.

Args:
bucket (str): Materials project bucket name
key (str): Key for file including all prefixes
decoder(Callable): Callable used to deserialize data
decoder(Callable or None): Callable used to deserialize data.
Defaults to mp_api.core.utils.load_json

Returns:
dict: MontyDecoded data
"""
if not decoder:

def decoder(x):
return load_json(x, deser=self.monty_decode)

file = open(
f"s3://{bucket}/{key}",
encoding="utf-8",
Expand Down Expand Up @@ -527,16 +522,11 @@ def _query_resource(
"Ignoring `fields` argument: All fields are always included when no query is provided."
)

decoder = (
MontyDecoder().decode if self.monty_decode else json_util.loads
)

# Multithreaded function inputs
s3_params_list = {
key: {
"bucket": bucket,
"key": key,
"decoder": decoder,
}
for key in keys
}
Expand Down Expand Up @@ -1013,11 +1003,7 @@ def _submit_request_and_process(
)

if response.status_code == 200:
if self.monty_decode:
data = json.loads(response.text, cls=MontyDecoder)
else:
data = json.loads(response.text)

data = load_json(response.text, deser=self.monty_decode)
# other sub-urls may use different document models
# the client does not handle this in a particularly smart way currently
if self.document_model and use_document_model:
Expand All @@ -1029,7 +1015,7 @@ def _submit_request_and_process(

else:
try:
data = json.loads(response.text)["detail"]
data = load_json(response.text)["detail"]
except (JSONDecodeError, KeyError):
data = f"Response {response.text}"
if isinstance(data, str):
Expand Down
13 changes: 11 additions & 2 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@
from functools import cache
from typing import Optional, get_args

from maggma.utils import get_flat_models_from_model
from monty.json import MSONable
import orjson
from emmet.core.utils import get_flat_models_from_model
from monty.json import MontyDecoder, MSONable
from pydantic import BaseModel
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo

from mp_api.client.core.settings import MAPIClientSettings


def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"):
"""Utility to load json in consistent manner."""
data = orjson.loads(
json_like if isinstance(json_like, bytes) else json_like.encode(encoding)
)
return MontyDecoder().process_decoded(data) if deser else data


def validate_ids(id_list: list[str]):
"""Function to validate material and task IDs.

Expand Down
12 changes: 5 additions & 7 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import itertools
import json
import os
import warnings
from functools import cache, lru_cache
from json import loads
from typing import TYPE_CHECKING

from emmet.core.electronic_structure import BSPathType
Expand All @@ -14,7 +12,6 @@
from emmet.core.tasks import TaskDoc
from emmet.core.thermo import ThermoType
from emmet.core.vasp.calc_types import CalcType
from monty.json import MontyDecoder
from packaging import version
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.analysis.pourbaix_diagram import IonEntry
Expand All @@ -27,14 +24,15 @@

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import validate_ids
from mp_api.client.core.utils import load_json, validate_ids
from mp_api.client.routes import GeneralStoreRester, MessagesRester, UserSettingsRester
from mp_api.client.routes.materials import (
AbsorptionRester,
AlloysRester,
BandStructureRester,
BondsRester,
ChemenvRester,
ConversionElectrodeRester,
DielectricRester,
DOIRester,
DosRester,
Expand Down Expand Up @@ -99,6 +97,7 @@ class MPRester:
robocrys: RobocrysRester
synthesis: SynthesisRester
insertion_electrodes: ElectrodeRester
conversion_electrodes: ConversionElectrodeRester
electronic_structure: ElectronicStructureRester
electronic_structure_bandstructure: BandStructureRester
electronic_structure_dos: DosRester
Expand Down Expand Up @@ -1338,11 +1337,10 @@ def get_charge_density_from_task_id(
Returns:
(Chgcar, (Chgcar, TaskDoc | dict), None): Pymatgen Chgcar object, or tuple with object and TaskDoc
"""
decoder = MontyDecoder().decode if self.monty_decode else json.loads
kwargs = dict(
bucket="materialsproject-parsed",
key=f"chgcars/{str(task_id)}.json.gz",
decoder=decoder,
decoder=lambda x: load_json(x, deser=self.monty_decode),
)
chgcar = self.materials.tasks._query_open_data(**kwargs)[0]
if not chgcar:
Expand Down Expand Up @@ -1476,7 +1474,7 @@ def _check_nomad_exist(url) -> bool:
response = get(url=url)
if response.status_code != 200:
return False
content = loads(response.text)
content = load_json(response.text)
if content["pagination"]["total"] == 0:
return False
return True
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dielectric import DielectricRester
from .doi import DOIRester
from .elasticity import ElasticityRester
from .electrodes import ElectrodeRester
from .electrodes import ConversionElectrodeRester, ElectrodeRester
from .electronic_structure import (
BandStructureRester,
DosRester,
Expand Down
47 changes: 40 additions & 7 deletions mp_api/client/routes/materials/electrodes.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import warnings
from collections import defaultdict

from emmet.core.electrode import InsertionElectrodeDoc
from emmet.core.electrode import ConversionElectrodeDoc, InsertionElectrodeDoc
from pymatgen.core.periodic_table import Element

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids


class ElectrodeRester(BaseRester[InsertionElectrodeDoc]):
suffix = "materials/insertion_electrodes"
document_model = InsertionElectrodeDoc # type: ignore
class BaseElectrodeRester(BaseRester):
primary_key = "battery_id"
_exclude_search_fields: list[str] | None = None

def search( # pragma: ignore
self,
Expand All @@ -38,7 +38,7 @@ def search( # pragma: ignore
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
) -> list[InsertionElectrodeDoc] | list[dict]:
) -> list[InsertionElectrodeDoc | ConversionElectrodeDoc] | list[dict]:
"""Query using a variety of search criteria.

Arguments:
Expand Down Expand Up @@ -74,11 +74,11 @@ def search( # pragma: ignore
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk.
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in InsertionElectrodeDoc to return data for.
fields (List[str]): List of fields in InsertionElectrodeDoc or ConversionElectrodeDoc to return data for.
Default is battery_id and last_updated if all_fields is False.

Returns:
([InsertionElectrodeDoc], [dict]) List of insertion electrode documents or dictionaries.
([InsertionElectrodeDoc or ConversionElectrodeDoc], [dict]) List of insertion/conversion electrode documents or dictionaries.
"""
query_params = defaultdict(dict) # type: dict

Expand Down Expand Up @@ -134,10 +134,43 @@ def search( # pragma: ignore
else:
query_params.update({param: value})

excluded_fields = self._exclude_search_fields or []
ignored_fields = {
entry
for entry in excluded_fields
if query_params.pop(entry, None) is not None
}
if ignored_fields:
warnings.warn(
f"Ignoring fields {', '.join(ignored_fields)} which are not valid options for {self.__class__.__name__}"
)

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super()._search(**query_params)


class ElectrodeRester(BaseElectrodeRester):
"""Search insertion electrode documents."""

suffix = "materials/insertion_electrodes"
document_model = InsertionElectrodeDoc # type: ignore


class ConversionElectrodeRester(BaseElectrodeRester):
"""Search conversion electrode documents."""

suffix = "materials/conversion_electrodes"
document_model = ConversionElectrodeDoc # type: ignore
# TODO: formula, chemsys, and elements do not appear to work in the API
_exclude_search_fields = [
"formula",
"chemsys",
"elements",
"stability_charge",
"stability_discharge",
]
6 changes: 0 additions & 6 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
import warnings
from collections import defaultdict

Expand All @@ -9,7 +8,6 @@
DOSProjectionType,
ElectronicStructureDoc,
)
from monty.json import MontyDecoder
from pymatgen.analysis.magnetism.analyzer import Ordering
from pymatgen.core.periodic_table import Element
from pymatgen.electronic_structure.core import OrbitalType, Spin
Expand Down Expand Up @@ -234,11 +232,9 @@ def get_bandstructure_from_task_id(self, task_id: str):
Returns:
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""
decoder = MontyDecoder().decode if self.monty_decode else json.loads
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"bandstructures/{task_id}.json.gz",
decoder=decoder,
)[0]

if result:
Expand Down Expand Up @@ -430,11 +426,9 @@ def get_dos_from_task_id(self, task_id: str):
Returns:
bandstructure (CompleteDos): CompleteDos object
"""
decoder = MontyDecoder().decode if self.monty_decode else json.loads
result = self._query_open_data(
bucket="materialsproject-parsed",
key=f"dos/{task_id}.json.gz",
decoder=decoder,
)[0]

if result:
Expand Down
3 changes: 3 additions & 0 deletions mp_api/client/routes/materials/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BandStructureRester,
BondsRester,
ChemenvRester,
ConversionElectrodeRester,
DielectricRester,
DosRester,
ElasticityRester,
Expand Down Expand Up @@ -62,6 +63,7 @@ class MaterialsRester(BaseRester[MaterialsDoc]):
"robocrys",
"synthesis",
"insertion_electrodes",
"conversion_electrodes",
"electronic_structure",
"electronic_structure_bandstructure",
"electronic_structure_dos",
Expand Down Expand Up @@ -92,6 +94,7 @@ class MaterialsRester(BaseRester[MaterialsDoc]):
robocrys: RobocrysRester
synthesis: SynthesisRester
insertion_electrodes: ElectrodeRester
conversion_electrodes: ConversionElectrodeRester
electronic_structure: ElectronicStructureRester
electronic_structure_bandstructure: BandStructureRester
electronic_structure_dos: DosRester
Expand Down
Loading
Loading