diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index c5e7328c8..788371ed0 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -60,7 +60,7 @@ jobs: #MP_API_ENDPOINT: https://api-preview.materialsproject.org/ run: | pip install -e . - pytest -x --cov=mp_api --cov-report=xml + pytest -n auto -x --cov=mp_api --cov-report=xml - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3f20b1a0..d7983abcf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -default_stages: [commit] +default_stages: [pre-commit] default_install_hook_types: [pre-commit, commit-msg] ci: @@ -37,6 +37,6 @@ repos: rev: v2.2.6 hooks: - id: codespell - stages: [commit, commit-msg] + stages: [pre-commit, commit-msg] exclude_types: [json, bib, svg] args: [--ignore-words-list, "mater,fwe,te"] diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 53024ca6d..3ee0b32e4 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -15,10 +15,17 @@ from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from copy import copy from functools import cache +from importlib import import_module from importlib.metadata import PackageNotFoundError, version from json import JSONDecodeError from math import ceil -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import ( + TYPE_CHECKING, + ForwardRef, + Generic, + TypeVar, + get_args, +) from urllib.parse import quote, urljoin import requests @@ -65,7 +72,7 @@ class BaseRester(Generic[T]): """Base client class with core stubs.""" suffix: str = "" - document_model: BaseModel = None # type: ignore + document_model: type[BaseModel] | None = None supports_versions: bool = False primary_key: str = "material_id" @@ -1070,10 +1077,24 @@ def _convert_to_model(self, data: list[dict]): def _generate_returned_model(self, doc): model_fields = self.document_model.model_fields + set_fields = doc.model_fields_set unset_fields = [field for field in model_fields if field not in set_fields] + + # Update with locals() from external module if needed + other_vars = {} + if any( + isinstance(typ, ForwardRef) + for field_meta in model_fields.values() + for typ in get_args(field_meta.annotation) + ): + other_vars = vars(import_module(self.document_model.__module__)) + include_fields = { - name: (model_fields[name].annotation, model_fields[name]) + name: ( + model_fields[name].annotation, + model_fields[name], + ) for name in set_fields } @@ -1085,6 +1106,8 @@ def _generate_returned_model(self, doc): fields_not_requested=(list[str], unset_fields), __base__=self.document_model, ) + if other_vars: + data_model.model_rebuild(_types_namespace=other_vars) def new_repr(self) -> str: extra = ",\n".join( diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index fb25221ff..1bba91954 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -71,9 +71,7 @@ def api_sanitize( for model in models: model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]} - for name in model.model_fields: - field = model.model_fields[name] - field_json_extra = field.json_schema_extra + for name, field in model.model_fields.items(): field_type = field.annotation if field_type is not None and allow_dict_msonable: @@ -88,7 +86,14 @@ def api_sanitize( new_field = FieldInfo.from_annotated_attribute( Optional[field_type], None ) - new_field.json_schema_extra = field_json_extra or {} + + for attr in ( + "json_schema_extra", + "exclude", + ): + if (val := getattr(field, attr)) is not None: + setattr(new_field, attr, val) + model.model_fields[name] = new_field model.model_rebuild(force=True) diff --git a/pyproject.toml b/pyproject.toml index b75c1f67d..b257994d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ test = [ "pytest-asyncio", "pytest-cov", "pytest-mock", + "pytest-xdist", "flake8", "pycodestyle", "mypy", diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index f5519ba17..f624f0b93 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -17,9 +17,9 @@ bcrypt==4.3.0 # via paramiko bibtexparser==1.4.3 # via pymatgen -boto3==1.40.29 +boto3==1.40.31 # via maggma -botocore==1.40.29 +botocore==1.40.31 # via # boto3 # s3transfer @@ -83,7 +83,7 @@ msgpack==1.1.1 # via # maggma # mp-api (pyproject.toml) -narwhals==2.4.0 +narwhals==2.5.0 # via plotly networkx==3.5 # via pymatgen @@ -123,7 +123,7 @@ pybtex==0.25.1 # via emmet-core pycparser==2.23 # via cffi -pydantic==2.11.7 +pydantic==2.11.9 # via # emmet-core # maggma @@ -149,7 +149,7 @@ pymongo==4.10.1 # via maggma pynacl==1.6.0 # via paramiko -pyparsing==3.2.3 +pyparsing==3.2.4 # via # bibtexparser # matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 7b2a2d2cd..a6ef0f277 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -29,11 +29,11 @@ bibtexparser==1.4.3 # via pymatgen boltons==25.0.0 # via mpcontribs-client -boto3==1.40.29 +boto3==1.40.31 # via # maggma # mp-api (pyproject.toml) -botocore==1.40.29 +botocore==1.40.31 # via # boto3 # s3transfer @@ -76,6 +76,8 @@ docutils==0.21.2 # via sphinx emmet-core[all]==0.84.10rc2 # via mp-api (pyproject.toml) +execnet==2.1.1 + # via pytest-xdist executing==2.2.1 # via stack-data filelock==3.19.1 @@ -221,7 +223,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.4.0 +narwhals==2.5.0 # via plotly networkx==3.5 # via @@ -343,7 +345,7 @@ pycodestyle==2.14.0 # mp-api (pyproject.toml) pycparser==2.23 # via cffi -pydantic==2.11.7 +pydantic==2.11.9 # via # emmet-core # maggma @@ -395,7 +397,7 @@ pymongo==4.10.1 # mpcontribs-client pynacl==1.6.0 # via paramiko -pyparsing==3.2.3 +pyparsing==3.2.4 # via # bibtexparser # matplotlib @@ -405,13 +407,16 @@ pytest==8.4.2 # pytest-asyncio # pytest-cov # pytest-mock + # pytest-xdist # solvation-analysis -pytest-asyncio==1.1.0 +pytest-asyncio==1.2.0 # via mp-api (pyproject.toml) pytest-cov==7.0.0 # via mp-api (pyproject.toml) pytest-mock==3.15.0 # via mp-api (pyproject.toml) +pytest-xdist==3.8.0 + # via mp-api (pyproject.toml) python-dateutil==2.9.0.post0 # via # arrow @@ -582,7 +587,7 @@ typeguard==4.4.4 # via inflect types-python-dateutil==2.9.0.20250822 # via arrow -types-requests==2.32.4.20250809 +types-requests==2.32.4.20250913 # via mp-api (pyproject.toml) types-setuptools==80.9.0.20250822 # via mp-api (pyproject.toml) @@ -599,6 +604,7 @@ typing-extensions==4.15.0 # pydantic # pydantic-core # pydash + # pytest-asyncio # referencing # spglib # swagger-spec-validator diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index 5668df40c..272f068a4 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -17,9 +17,9 @@ bcrypt==4.3.0 # via paramiko bibtexparser==1.4.3 # via pymatgen -boto3==1.40.29 +boto3==1.40.31 # via maggma -botocore==1.40.29 +botocore==1.40.31 # via # boto3 # s3transfer @@ -83,7 +83,7 @@ msgpack==1.1.1 # via # maggma # mp-api (pyproject.toml) -narwhals==2.4.0 +narwhals==2.5.0 # via plotly networkx==3.5 # via pymatgen @@ -123,7 +123,7 @@ pybtex==0.25.1 # via emmet-core pycparser==2.23 # via cffi -pydantic==2.11.7 +pydantic==2.11.9 # via # emmet-core # maggma @@ -149,7 +149,7 @@ pymongo==4.10.1 # via maggma pynacl==1.6.0 # via paramiko -pyparsing==3.2.3 +pyparsing==3.2.4 # via # bibtexparser # matplotlib diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 3770d3d18..a2d27d1ac 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -29,11 +29,11 @@ bibtexparser==1.4.3 # via pymatgen boltons==25.0.0 # via mpcontribs-client -boto3==1.40.29 +boto3==1.40.31 # via # maggma # mp-api (pyproject.toml) -botocore==1.40.29 +botocore==1.40.31 # via # boto3 # s3transfer @@ -76,6 +76,8 @@ docutils==0.21.2 # via sphinx emmet-core[all]==0.84.10rc2 # via mp-api (pyproject.toml) +execnet==2.1.1 + # via pytest-xdist executing==2.2.1 # via stack-data filelock==3.19.1 @@ -221,7 +223,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.4.0 +narwhals==2.5.0 # via plotly networkx==3.5 # via @@ -343,7 +345,7 @@ pycodestyle==2.14.0 # mp-api (pyproject.toml) pycparser==2.23 # via cffi -pydantic==2.11.7 +pydantic==2.11.9 # via # emmet-core # maggma @@ -395,7 +397,7 @@ pymongo==4.10.1 # mpcontribs-client pynacl==1.6.0 # via paramiko -pyparsing==3.2.3 +pyparsing==3.2.4 # via # bibtexparser # matplotlib @@ -405,13 +407,16 @@ pytest==8.4.2 # pytest-asyncio # pytest-cov # pytest-mock + # pytest-xdist # solvation-analysis -pytest-asyncio==1.1.0 +pytest-asyncio==1.2.0 # via mp-api (pyproject.toml) pytest-cov==7.0.0 # via mp-api (pyproject.toml) pytest-mock==3.15.0 # via mp-api (pyproject.toml) +pytest-xdist==3.8.0 + # via mp-api (pyproject.toml) python-dateutil==2.9.0.post0 # via # arrow @@ -582,7 +587,7 @@ typeguard==4.4.4 # via inflect types-python-dateutil==2.9.0.20250822 # via arrow -types-requests==2.32.4.20250809 +types-requests==2.32.4.20250913 # via mp-api (pyproject.toml) types-setuptools==80.9.0.20250822 # via mp-api (pyproject.toml) @@ -598,6 +603,7 @@ typing-extensions==4.15.0 # pydantic # pydantic-core # pydash + # pytest-asyncio # referencing # spglib # swagger-spec-validator diff --git a/tests/materials/test_electronic_structure.py b/tests/materials/test_electronic_structure.py index 576784258..99b2060ca 100644 --- a/tests/materials/test_electronic_structure.py +++ b/tests/materials/test_electronic_structure.py @@ -47,7 +47,7 @@ def es_rester(): @pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") -@pytest.mark.skip(reason="magnetic ordering fields not build correctly") +@pytest.mark.skip(reason="magnetic ordering fields not built correctly") def test_es_client(es_rester): search_method = es_rester.search @@ -81,7 +81,7 @@ def bs_rester(): @pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") -@pytest.mark.skip(reason="magnetic ordering fields not build correctly") +@pytest.mark.skip(reason="magnetic ordering fields not built correctly") def test_bs_client(bs_rester): # Get specific search method search_method = bs_rester.search @@ -127,7 +127,7 @@ def dos_rester(): @pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") -@pytest.mark.skip(reason="magnetic ordering fields not build correctly") +@pytest.mark.skip(reason="magnetic ordering fields not built correctly") def test_dos_client(dos_rester): search_method = dos_rester.search