Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
from typing import (
TYPE_CHECKING,
ForwardRef,
Generic,
Optional,
TypeVar,
get_args,
)
from urllib.parse import quote, urljoin
Expand Down Expand Up @@ -65,10 +63,8 @@

SETTINGS = MAPIClientSettings() # type: ignore

T = TypeVar("T")


class BaseRester(Generic[T]):
class BaseRester:
"""Base client class with core stubs."""

suffix: str = ""
Expand Down Expand Up @@ -140,15 +136,8 @@ def __init__(
if not self.endpoint.endswith("/"):
self.endpoint += "/"

if session:
self._session = session
else:
self._session = None # type: ignore

if s3_client:
self._s3_client = s3_client
else:
self._s3_client = None
self._session = session
self._s3_client = s3_client

@property
def session(self) -> requests.Session:
Expand Down Expand Up @@ -596,7 +585,7 @@ def _submit_requests( # noqa
url: url used to make request
use_document_model: if None, will defer to the self.use_document_model attribute
parallel_param: parameter to parallelize requests with
num_chu: fieldsnky: Maximum number of chunks of data to yield. None will yield all possible.
num_chunks: Maximum number of chunks of data to yield. None will yield all possible.
chunk_size: Number of data entries per chunk.
timeout: Time in seconds to wait until a request timeout error is thrown

Expand Down Expand Up @@ -1077,7 +1066,9 @@ def _generate_returned_model(
include_fields: dict[str, tuple[type, FieldInfo]] = {}
for name in set_fields:
field_copy = model_fields[name]._copy()
field_copy.default = None
if not field_copy.default_factory:
# Fields with a default_factory cannot also have a default in pydantic>=2.12.3
field_copy.default = None
include_fields[name] = (
Optional[model_fields[name].annotation],
field_copy,
Expand All @@ -1097,8 +1088,6 @@ def _generate_returned_model(
),
__module__=self.document_model.__module__,
)
# if other_vars:
# data_model.model_rebuild(_types_namespace=other_vars)

orig_rester_name = self.document_model.__name__

Expand Down Expand Up @@ -1151,7 +1140,7 @@ def _query_resource_data(
suburl: str | None = None,
use_document_model: bool | None = None,
timeout: int | None = None,
) -> list[T] | list[dict]:
) -> list[BaseModel] | list[dict]:
"""Query the endpoint for a list of documents without associated meta information. Only
returns a single page of results.

Expand Down Expand Up @@ -1181,7 +1170,7 @@ def _search(
all_fields: bool = True,
fields: list[str] | None = None,
**kwargs,
) -> list[T] | list[dict]:
) -> list[BaseModel] | list[dict]:
"""A generic search method to retrieve documents matching specific parameters.

Arguments:
Expand Down Expand Up @@ -1216,7 +1205,7 @@ def get_data_by_id(
self,
document_id: str,
fields: list[str] | None = None,
) -> T | dict:
) -> BaseModel | dict:
warnings.warn(
"get_data_by_id is deprecated and will be removed soon. Please use the search method instead.",
DeprecationWarning,
Expand Down Expand Up @@ -1251,7 +1240,7 @@ def _get_all_documents(
fields=None,
chunk_size=1000,
num_chunks=None,
) -> list[T] | list[dict]:
) -> list[BaseModel] | list[dict]:
"""Iterates over pages until all documents are retrieved. Displays
progress using tqdm. This method is designed to give a common
implementation for the search_* methods on various endpoints. See
Expand Down
9 changes: 5 additions & 4 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,10 +1203,11 @@ def get_entries_in_chemsys(

elements_set = set(elements) # remove duplicate elements

all_chemsyses = []
for i in range(len(elements_set)):
for els in itertools.combinations(elements_set, i + 1):
all_chemsyses.append("-".join(sorted(els)))
all_chemsyses = [
"-".join(sorted(els))
for i in range(len(elements_set))
for els in itertools.combinations(elements_set, i + 1)
]

entries = []

Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/_general_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mp_api.client.core import BaseRester


class GeneralStoreRester(BaseRester[GeneralStoreDoc]): # pragma: no cover
class GeneralStoreRester(BaseRester): # pragma: no cover
suffix = "_general_store"
document_model = GeneralStoreDoc # type: ignore
primary_key = "submission_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mp_api.client.core import BaseRester


class MessagesRester(BaseRester[MessagesDoc]): # pragma: no cover
class MessagesRester(BaseRester): # pragma: no cover
suffix = "_messages"
document_model = MessagesDoc # type: ignore
primary_key = "title"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/_user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mp_api.client.core import BaseRester


class UserSettingsRester(BaseRester[UserSettingsDoc]): # pragma: no cover
class UserSettingsRester(BaseRester): # pragma: no cover
suffix = "_user_settings"
document_model = UserSettingsDoc # type: ignore
primary_key = "consumer_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class AbsorptionRester(BaseRester[AbsorptionDoc]):
class AbsorptionRester(BaseRester):
suffix = "materials/absorption"
document_model = AbsorptionDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/alloys.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class AlloysRester(BaseRester[AlloyPairDoc]):
class AlloysRester(BaseRester):
suffix = "materials/alloys"
document_model = AlloyPairDoc # type: ignore
primary_key = "pair_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class BondsRester(BaseRester[BondingDoc]):
class BondsRester(BaseRester):
suffix = "materials/bonds"
document_model = BondingDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/chemenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mp_api.client.core.utils import validate_ids


class ChemenvRester(BaseRester[ChemEnvDoc]):
class ChemenvRester(BaseRester):
suffix = "materials/chemenv"
document_model = ChemEnvDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/dielectric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class DielectricRester(BaseRester[DielectricDoc]):
class DielectricRester(BaseRester):
suffix = "materials/dielectric"
document_model = DielectricDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/doi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class DOIRester(BaseRester[DOIDoc]):
class DOIRester(BaseRester):
suffix = "doi"
document_model = DOIDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class ElasticityRester(BaseRester[ElasticityDoc]):
class ElasticityRester(BaseRester):
suffix = "materials/elasticity"
document_model = ElasticityDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mp_api.client.core.utils import validate_ids


class ElectronicStructureRester(BaseRester[ElectronicStructureDoc]):
class ElectronicStructureRester(BaseRester):
suffix = "materials/electronic_structure"
document_model = ElectronicStructureDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class EOSRester(BaseRester[EOSDoc]):
class EOSRester(BaseRester):
suffix = "materials/eos"
document_model = EOSDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/grain_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class GrainBoundaryRester(BaseRester[GrainBoundaryDoc]):
class GrainBoundaryRester(BaseRester):
suffix = "materials/grain_boundaries"
document_model = GrainBoundaryDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/magnetism.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mp_api.client.core.utils import validate_ids


class MagnetismRester(BaseRester[MagnetismDoc]):
class MagnetismRester(BaseRester):
suffix = "materials/magnetism"
document_model = MagnetismDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
_EMMET_SETTINGS = EmmetSettings() # type: ignore


class MaterialsRester(BaseRester[MaterialsDoc]):
class MaterialsRester(BaseRester):
suffix = "materials/core"
document_model = MaterialsDoc # type: ignore
supports_versions = True
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/oxidation_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class OxidationStatesRester(BaseRester[OxidationStateDoc]):
class OxidationStatesRester(BaseRester):
suffix = "materials/oxidation_states"
document_model = OxidationStateDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/phonon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mp_api.client.core.utils import validate_ids


class PhononRester(BaseRester[PhononBSDOSDoc]):
class PhononRester(BaseRester):
suffix = "materials/phonon"
document_model = PhononBSDOSDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/piezo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class PiezoRester(BaseRester[PiezoelectricDoc]):
class PiezoRester(BaseRester):
suffix = "materials/piezoelectric"
document_model = PiezoelectricDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mp_api.client.core.utils import validate_ids


class ProvenanceRester(BaseRester[ProvenanceDoc]):
class ProvenanceRester(BaseRester):
suffix = "materials/provenance"
document_model = ProvenanceDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/robocrys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mp_api.client.core.utils import validate_ids


class RobocrysRester(BaseRester[RobocrystallogapherDoc]):
class RobocrysRester(BaseRester):
suffix = "materials/robocrys"
document_model = RobocrystallogapherDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mp_api.client.core.utils import validate_ids


class SimilarityRester(BaseRester[SimilarityDoc]):
class SimilarityRester(BaseRester):
suffix = "materials/similarity"
document_model = SimilarityDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/substrates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mp_api.client.core import BaseRester


class SubstratesRester(BaseRester[SubstratesDoc]):
class SubstratesRester(BaseRester):
suffix = "materials/substrates"
document_model = SubstratesDoc # type: ignore
primary_key = "film_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mp_api.client.core.utils import validate_ids


class SummaryRester(BaseRester[SummaryDoc]):
class SummaryRester(BaseRester):
suffix = "materials/summary"
document_model = SummaryDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/surface_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mp_api.client.core.utils import validate_ids


class SurfacePropertiesRester(BaseRester[SurfacePropDoc]):
class SurfacePropertiesRester(BaseRester):
suffix = "materials/surface_properties"
document_model = SurfacePropDoc # type: ignore
primary_key = "material_id"
Expand Down
2 changes: 1 addition & 1 deletion mp_api/client/routes/materials/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mp_api.client.core import BaseRester, MPRestError


class SynthesisRester(BaseRester[SynthesisSearchResultModel]):
class SynthesisRester(BaseRester):
suffix = "materials/synthesis"
document_model = SynthesisSearchResultModel # type: ignore

Expand Down
20 changes: 12 additions & 8 deletions mp_api/client/routes/materials/tasks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

from emmet.core.tasks import TaskDoc
from emmet.core.tasks import CoreTaskDoc

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.utils import validate_ids

if TYPE_CHECKING:
from pydantic import BaseModel

class TaskRester(BaseRester[TaskDoc]):
suffix = "materials/tasks"
document_model = TaskDoc # type: ignore
primary_key = "task_id"

class TaskRester(BaseRester):
suffix: str = "materials/tasks"
document_model: type[BaseModel] = CoreTaskDoc # type: ignore
primary_key: str = "task_id"

def get_trajectory(self, task_id):
"""Returns a Trajectory object containing the geometry of the
Expand Down Expand Up @@ -44,7 +48,7 @@ def search(
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
) -> list[TaskDoc] | list[dict]:
) -> list[CoreTaskDoc] | list[dict]:
"""Query core task docs using a variety of search criteria.

Arguments:
Expand All @@ -58,11 +62,11 @@ def search(
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk. Max size is 100.
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in TaskDoc to return data for.
fields (List[str]): List of fields in CoreTaskDoc to return data for.
Default is material_id, last_updated, and formula_pretty if all_fields is False.

Returns:
([TaskDoc], [dict]) List of task documents or dictionaries.
([CoreTaskDoc], [dict]) List of task documents or dictionaries.
"""
query_params = {} # type: dict

Expand Down
Loading