Skip to content

Commit 17136ba

Browse files
correct absorption rester + add test
1 parent 0abf835 commit 17136ba

File tree

4 files changed

+98
-31
lines changed

4 files changed

+98
-31
lines changed

mp_api/client/routes/materials/absorption.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ class AbsorptionRester(BaseRester):
1616
def search(
1717
self,
1818
material_ids: str | list[str] | None = None,
19-
chemsys: str | list[str] | None = None,
20-
elements: list[str] | None = None,
21-
exclude_elements: list[str] | None = None,
22-
formula: list[str] | None = None,
19+
num_sites: int | tuple[int, int] | None = None,
20+
num_elements: int | tuple[int, int] | None = None,
21+
volume: float | tuple[float, float] | None = None,
22+
density: float | tuple[float, float] | None = None,
23+
band_gap: float | tuple[float, float] | None = None,
2324
num_chunks: int | None = None,
2425
chunk_size: int = 1000,
2526
all_fields: bool = True,
@@ -29,13 +30,6 @@ def search(
2930
3031
Arguments:
3132
material_ids (str, List[str]): Search for optical absorption data associated with the specified Material IDs
32-
chemsys (str, List[str]): A chemical system or list of chemical systems
33-
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
34-
elements (List[str]): A list of elements.
35-
exclude_elements (List[str]): A list of elements to exclude.
36-
formula (str, List[str]): A formula including anonymized formula
37-
or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
38-
(e.g., [Fe2O3, ABO3]).
3933
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
4034
chunk_size (int): Number of data entries per chunk.
4135
all_fields (bool): Whether to return all fields in the document. Defaults to True.
@@ -46,23 +40,25 @@ def search(
4640
"""
4741
query_params = defaultdict(dict) # type: dict
4842

49-
if formula:
50-
if isinstance(formula, str):
51-
formula = [formula]
52-
53-
query_params.update({"formula": ",".join(formula)})
54-
55-
if chemsys:
56-
if isinstance(chemsys, str):
57-
chemsys = [chemsys]
58-
59-
query_params.update({"chemsys": ",".join(chemsys)})
60-
61-
if elements:
62-
query_params.update({"elements": ",".join(elements)})
63-
64-
if exclude_elements:
65-
query_params.update({"exclude_elements": ",".join(exclude_elements)})
43+
aliased = {
44+
"num_sites": "nsites",
45+
"num_elements": "nelements",
46+
"band_gap": "bandgap",
47+
}
48+
user_query = locals()
49+
for k in ("num_sites","num_elements","volume","density","band_gap"):
50+
if (value := user_query.get(k)) is not None:
51+
if k in ("num_sites","num_elements") and isinstance(value, int):
52+
value = (value, value)
53+
elif k in ("volume","density","band_gap") and isinstance(value,int | float):
54+
value = (value - 1e-2, value + 1e-2)
55+
56+
query_params.update(
57+
{
58+
f"{aliased.get(k,k)}_min": value[0],
59+
f"{aliased.get(k,k)}_max": value[1],
60+
}
61+
)
6662

6763
if material_ids:
6864
if isinstance(material_ids, str):
@@ -77,7 +73,6 @@ def search(
7773
}
7874

7975
return super()._search(
80-
formulae=formula,
8176
num_chunks=num_chunks,
8277
chunk_size=chunk_size,
8378
all_fields=all_fields,

mp_api/client/routes/materials/summary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def search( # noqa: D417
4848
magnetic_ordering: Ordering | None = None,
4949
material_ids: str | list[str] | None = None,
5050
n: tuple[float, float] | None = None,
51-
num_elements: tuple[int, int] | None = None,
52-
num_sites: tuple[int, int] | None = None,
51+
num_elements: int | tuple[int, int] | None = None,
52+
num_sites: int | tuple[int, int] | None = None,
5353
num_magnetic_sites: tuple[int, int] | None = None,
5454
num_unique_magnetic_sites: tuple[int, int] | None = None,
5555
piezoelectric_modulus: tuple[float, float] | None = None,

tests/materials/test_absorption.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
3+
import pytest
4+
5+
from emmet.core.phonon import PhononBS, PhononDOS
6+
7+
from mp_api.client.core import MPRestError
8+
from mp_api.client.routes.materials.absorption import AbsorptionRester
9+
10+
from core_function import client_search_testing
11+
12+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
13+
def test_absorption_search():
14+
15+
client_search_testing(
16+
search_method=AbsorptionRester().search,
17+
excluded_params=[
18+
"num_chunks",
19+
"chunk_size",
20+
"all_fields",
21+
"fields",
22+
],
23+
alt_name_dict={
24+
"material_ids": "material_id",
25+
"num_sites": "nsites",
26+
"num_elements": "nelements",
27+
"band_gap": "bandgap",
28+
},
29+
custom_field_tests={
30+
"material_ids": ["mp-149", "mp-239"],
31+
"material_ids": "mp-149",
32+
"num_sites": (6,7),
33+
"num_sites": 7,
34+
"num_elements": 5,
35+
"num_elements": (4,5),
36+
"volume": (115,116),
37+
"volume": 115.5,
38+
"density": (2.9,3),
39+
"density": 2.933,
40+
"band_gap": (1,1.05),
41+
"band_gap": 1.,
42+
},
43+
sub_doc_fields=[],
44+
)

tests/materials/test_doi.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
3+
import pytest
4+
5+
from mp_api.client.routes.materials import DOIRester
6+
7+
from core_function import client_search_testing
8+
9+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
10+
def test_doi_search():
11+
12+
client_search_testing(
13+
search_method=DOIRester().search,
14+
excluded_params=[
15+
"num_chunks",
16+
"chunk_size",
17+
"all_fields",
18+
"fields",
19+
],
20+
alt_name_dict={
21+
"material_ids": "material_id",
22+
},
23+
custom_field_tests={
24+
"material_ids": ["mp-149", "mp-13"],
25+
"material_ids": "mp-149",
26+
},
27+
sub_doc_fields=[],
28+
)

0 commit comments

Comments
 (0)