Skip to content

Commit 66a9103

Browse files
skip test for missing api key decorator
1 parent c97e9bb commit 66a9103

34 files changed

+137
-116
lines changed

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
from collections import defaultdict
5+
from typing import TYPE_CHECKING
56

67
from emmet.core.electronic_structure import (
78
BSPathType,
@@ -15,6 +16,9 @@
1516
from mp_api.client.core import BaseRester, MPRestError
1617
from mp_api.client.core.utils import validate_ids
1718

19+
if TYPE_CHECKING:
20+
from pymatgen.electronic_structure.dos import CompleteDos
21+
1822

1923
class ElectronicStructureRester(BaseRester):
2024
suffix = "materials/electronic_structure"
@@ -142,9 +146,28 @@ def search(
142146
)
143147

144148

145-
class BandStructureRester(BaseRester):
149+
class BaseESPropertyRester(BaseRester):
150+
_es_rester: ElectronicStructureRester | None = None
151+
document_model = ElectronicStructureDoc
152+
153+
@property
154+
def es_rester(self) -> ElectronicStructureRester:
155+
if not self._es_rester:
156+
self._es_rester = ElectronicStructureRester(
157+
api_key=self.api_key,
158+
endpoint=self.base_endpoint,
159+
include_user_agent=self.include_user_agent,
160+
session=self.session,
161+
monty_decode=self.monty_decode,
162+
use_document_model=self.use_document_model,
163+
headers=self.headers,
164+
mute_progress_bars=self.mute_progress_bars,
165+
)
166+
return self._es_rester
167+
168+
169+
class BandStructureRester(BaseESPropertyRester):
146170
suffix = "materials/electronic_structure/bandstructure"
147-
document_model = ElectronicStructureDoc # type: ignore
148171

149172
def search_bandstructure_summary(self, *args, **kwargs): # pragma: no cover
150173
"""Deprecated."""
@@ -258,19 +281,8 @@ def get_bandstructure_from_material_id(
258281
Returns:
259282
bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
260283
"""
261-
es_rester = ElectronicStructureRester(
262-
api_key=self.api_key,
263-
endpoint=self.base_endpoint,
264-
include_user_agent=self.include_user_agent,
265-
session=self.session,
266-
monty_decode=self.monty_decode,
267-
use_document_model=self.use_document_model,
268-
headers=self.headers,
269-
mute_progress_bars=self.mute_progress_bars,
270-
)
271-
272284
if line_mode:
273-
bs_doc = es_rester.search(
285+
bs_doc = self.es_rester.search(
274286
material_ids=material_id, fields=["bandstructure"]
275287
)
276288
if not bs_doc:
@@ -293,7 +305,9 @@ def get_bandstructure_from_material_id(
293305

294306
else:
295307
if not (
296-
bs_doc := es_rester.search(material_ids=material_id, fields=["dos"])
308+
bs_doc := self.es_rester.search(
309+
material_ids=material_id, fields=["dos"]
310+
)
297311
):
298312
raise MPRestError("No electronic structure data found.")
299313

@@ -319,9 +333,8 @@ def get_bandstructure_from_material_id(
319333
raise MPRestError("No band structure object found.")
320334

321335

322-
class DosRester(BaseRester):
336+
class DosRester(BaseESPropertyRester):
323337
suffix = "materials/electronic_structure/dos"
324-
document_model = ElectronicStructureDoc # type: ignore
325338

326339
def search_dos_summary(self, *args, **kwargs): # pragma: no cover
327340
"""Deprecated."""
@@ -403,7 +416,7 @@ def search(
403416
**query_params,
404417
)
405418

406-
def get_dos_from_task_id(self, task_id: str):
419+
def get_dos_from_task_id(self, task_id: str) -> CompleteDos:
407420
"""Get the density of states pymatgen object associated with a given calculation ID.
408421
409422
Arguments:
@@ -431,18 +444,9 @@ def get_dos_from_material_id(self, material_id: str):
431444
Returns:
432445
dos (CompleteDos): CompleteDos object
433446
"""
434-
es_rester = ElectronicStructureRester(
435-
api_key=self.api_key,
436-
endpoint=self.base_endpoint,
437-
include_user_agent=self.include_user_agent,
438-
session=self.session,
439-
monty_decode=self.monty_decode,
440-
use_document_model=self.use_document_model,
441-
headers=self.headers,
442-
mute_progress_bars=self.mute_progress_bars,
443-
)
444-
445-
if not (dos_doc := es_rester.search(material_ids=material_id, fields=["dos"])):
447+
if not (
448+
dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"])
449+
):
446450
return None
447451

448452
if not (dos_data := dos_doc[0].get("dos")):
Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
from __future__ import annotations
2-
from typing import Callable, Any
2+
3+
import os
4+
from typing import TYPE_CHECKING
5+
6+
import pytest
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Callable
10+
from typing import Any
11+
12+
requires_api_key = pytest.mark.skipif(
13+
os.getenv("MP_API_KEY") is None,
14+
reason="No API key found.",
15+
)
316

417

518
def client_search_testing(

tests/materials/__init__.py

Whitespace-only changes.

tests/materials/test_absorption.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from mp_api.client.core import MPRestError
88
from mp_api.client.routes.materials.absorption import AbsorptionRester
99

10-
from core_function import client_search_testing
10+
from ..conftest import client_search_testing, requires_api_key
1111

1212

13-
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
13+
@requires_api_key
1414
def test_absorption_search():
1515
client_search_testing(
1616
search_method=AbsorptionRester().search,

tests/materials/test_alloys.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from mp_api.client.routes.materials.alloys import AlloysRester
66

7-
from core_function import client_search_testing
7+
from ..conftest import client_search_testing, requires_api_key
88

99

10-
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
10+
@requires_api_key
1111
def test_alloys_search():
1212
client_search_testing(
1313
search_method=AlloysRester().search,

tests/materials/test_bonds.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from core_function import client_search_testing
2+
from ..conftest import client_search_testing, requires_api_key
33
import pytest
44
from mp_api.client.routes.materials.bonds import BondsRester
55

@@ -35,7 +35,7 @@ def rester():
3535
} # type: dict
3636

3737

38-
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
38+
@requires_api_key
3939
def test_client(rester):
4040
search_method = rester.search
4141

tests/materials/test_chemenv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from core_function import client_search_testing
2+
from ..conftest import client_search_testing, requires_api_key
33

44
import pytest
55

@@ -44,7 +44,7 @@ def rester():
4444
} # type: dict
4545

4646

47-
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
47+
@requires_api_key
4848
def test_client(rester):
4949
search_method = rester.search
5050

tests/materials/test_dielectric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from core_function import client_search_testing
2+
from ..conftest import client_search_testing, requires_api_key
33

44
import pytest
55

@@ -33,7 +33,7 @@ def rester():
3333
} # type: dict
3434

3535

36-
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
36+
@requires_api_key
3737
def test_client(rester):
3838
search_method = rester.search
3939

tests/materials/test_doi.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
import os
2-
3-
import pytest
4-
51
from mp_api.client.routes.materials import DOIRester
62

7-
from core_function import client_search_testing
3+
from ..conftest import client_search_testing, requires_api_key
84

95

10-
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
6+
@requires_api_key
117
def test_doi_search():
128
client_search_testing(
139
search_method=DOIRester().search,

tests/materials/test_elasticity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from core_function import client_search_testing
2+
from ..conftest import client_search_testing, requires_api_key
33

44
import pytest
55

@@ -38,7 +38,7 @@ def rester():
3838
custom_field_tests = {"material_ids": ["mp-149"]} # type: dict
3939

4040

41-
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
41+
@requires_api_key
4242
def test_client(rester):
4343
search_method = rester.search
4444

0 commit comments

Comments
 (0)