Skip to content

Commit b885aa1

Browse files
re-enable and tweak electronic structure tests
1 parent 5073026 commit b885aa1

File tree

2 files changed

+108
-74
lines changed

2 files changed

+108
-74
lines changed

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def search(
7575
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
7676
chunk_size (int): Number of data entries per chunk.
7777
all_fields (bool): Whether to return all fields in the document. Defaults to True.
78-
fields (List[str]): List of fields in EOSDoc to return data for.
78+
fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
7979
Default is material_id and last_updated if all_fields is False.
8080
8181
Returns:
@@ -186,8 +186,8 @@ def search(
186186
efermi: tuple[float, float] | None = None,
187187
is_gap_direct: bool | None = None,
188188
is_metal: bool | None = None,
189-
magnetic_ordering: Ordering | None = None,
190-
path_type: BSPathType = BSPathType.setyawan_curtarolo,
189+
magnetic_ordering: Ordering | str | None = None,
190+
path_type: BSPathType | str = BSPathType.setyawan_curtarolo,
191191
num_chunks: int | None = None,
192192
chunk_size: int = 1000,
193193
all_fields: bool = True,
@@ -200,8 +200,8 @@ def search(
200200
efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
201201
is_gap_direct (bool): Whether the material has a direct band gap.
202202
is_metal (bool): Whether the material is considered a metal.
203-
magnetic_ordering (Ordering): Magnetic ordering of the material.
204-
path_type (BSPathType): k-path selection convention for the band structure.
203+
magnetic_ordering (Ordering or str): Magnetic ordering of the material.
204+
path_type (BSPathType or str): k-path selection convention for the band structure.
205205
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
206206
chunk_size (int): Number of data entries per chunk.
207207
all_fields (bool): Whether to return all fields in the document. Defaults to True.
@@ -213,7 +213,9 @@ def search(
213213
"""
214214
query_params = defaultdict(dict) # type: dict
215215

216-
query_params["path_type"] = path_type.value
216+
query_params["path_type"] = (
217+
BSPathType[path_type] if isinstance(path_type, str) else path_type
218+
).value
217219

218220
if band_gap:
219221
query_params.update(
@@ -224,7 +226,15 @@ def search(
224226
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
225227

226228
if magnetic_ordering:
227-
query_params.update({"magnetic_ordering": magnetic_ordering.value})
229+
query_params.update(
230+
{
231+
"magnetic_ordering": (
232+
Ordering(magnetic_ordering)
233+
if isinstance(magnetic_ordering, str)
234+
else magnetic_ordering
235+
).value
236+
}
237+
)
228238

229239
if is_gap_direct is not None:
230240
query_params.update({"is_gap_direct": is_gap_direct})
@@ -351,11 +361,11 @@ def search(
351361
self,
352362
band_gap: tuple[float, float] | None = None,
353363
efermi: tuple[float, float] | None = None,
354-
element: Element | None = None,
355-
magnetic_ordering: Ordering | None = None,
356-
orbital: OrbitalType | None = None,
357-
projection_type: DOSProjectionType = DOSProjectionType.total,
358-
spin: Spin = Spin.up,
364+
element: Element | str | None = None,
365+
magnetic_ordering: Ordering | str | None = None,
366+
orbital: OrbitalType | str | None = None,
367+
projection_type: DOSProjectionType | str = DOSProjectionType.total,
368+
spin: Spin | str = Spin.up,
359369
num_chunks: int | None = None,
360370
chunk_size: int = 1000,
361371
all_fields: bool = True,
@@ -366,30 +376,54 @@ def search(
366376
Arguments:
367377
band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
368378
efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
369-
element (Element): Element for element-projected dos data.
370-
magnetic_ordering (Ordering): Magnetic ordering of the material.
371-
orbital (OrbitalType): Orbital for orbital-projected dos data.
372-
projection_type (DOSProjectionType): Projection type of dos data. Default is the total dos.
373-
spin (Spin): Spin channel of dos data. If non spin-polarized data is stored in Spin.up
379+
element (Element or str): Element for element-projected dos data.
380+
magnetic_ordering (Ordering or str): Magnetic ordering of the material.
381+
orbital (OrbitalType or str): Orbital for orbital-projected dos data.
382+
projection_type (DOSProjectionType or str): Projection type of dos data. Default is the total dos.
383+
spin (Spin or str): Spin channel of dos data. If non spin-polarized data is stored in Spin.up
374384
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
375385
chunk_size (int): Number of data entries per chunk.
376386
all_fields (bool): Whether to return all fields in the document. Defaults to True.
377-
fields (List[str]): List of fields in EOSDoc to return data for.
387+
fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
378388
Default is material_id and last_updated if all_fields is False.
379389
380390
Returns:
381391
([ElectronicStructureDoc]) List of electronic structure documents
382392
"""
383393
query_params = defaultdict(dict) # type: dict
384394

385-
query_params["projection_type"] = projection_type.value
386-
query_params["spin"] = spin.value
395+
query_params["projection_type"] = (
396+
DOSProjectionType[projection_type]
397+
if isinstance(projection_type, str)
398+
else projection_type
399+
).value
400+
query_params["spin"] = (Spin[spin] if isinstance(spin, str) else spin).value
401+
402+
if (
403+
query_params["projection_type"] == DOSProjectionType.elemental.value
404+
and element is None
405+
):
406+
raise MPRestError(
407+
"To query element-projected DOS, you must also specify the `element` onto which the DOS is projected."
408+
)
409+
410+
if (
411+
query_params["projection_type"] == DOSProjectionType.orbital.value
412+
and orbital is None
413+
):
414+
raise MPRestError(
415+
"To query orbital-projected DOS, you must also specify the `orbital` character onto which the DOS is projected."
416+
)
387417

388418
if element:
389-
query_params["element"] = element.value
419+
query_params["element"] = (
420+
Element[element] if isinstance(element, str) else element
421+
).value
390422

391423
if orbital:
392-
query_params["orbital"] = orbital.value
424+
query_params["orbital"] = (
425+
OrbitalType[orbital] if isinstance(orbital, str) else orbital
426+
).value
393427

394428
if band_gap:
395429
query_params.update(
@@ -400,7 +434,15 @@ def search(
400434
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
401435

402436
if magnetic_ordering:
403-
query_params.update({"magnetic_ordering": magnetic_ordering.value})
437+
query_params.update(
438+
{
439+
"magnetic_ordering": (
440+
Ordering[magnetic_ordering]
441+
if isinstance(magnetic_ordering, str)
442+
else magnetic_ordering
443+
).value
444+
}
445+
)
404446

405447
query_params = {
406448
entry: query_params[entry]

tests/materials/test_electronic_structure.py

Lines changed: 43 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def es_rester():
4747

4848

4949
@requires_api_key
50-
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
5150
def test_es_client(es_rester):
5251
search_method = es_rester.search
5352

@@ -66,51 +65,51 @@ def test_es_client(es_rester):
6665
"is_gap_direct": True,
6766
"efermi": (0, 100),
6867
"band_gap": (0, 5),
68+
"path_type": "hinuma",
6969
}
7070

7171
bs_sub_doc_fields = ["bandstructure"]
7272

7373
bs_alt_name_dict = {} # type: dict
7474

7575

76-
@pytest.fixture
77-
def bs_rester():
78-
rester = BandStructureRester()
79-
yield rester
80-
rester.session.close()
81-
82-
8376
@requires_api_key
84-
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
85-
def test_bs_client(bs_rester):
77+
def test_bs_client():
8678
# Get specific search method
87-
search_method = bs_rester.search
8879

89-
# Query fields
90-
for param in bs_custom_field_tests:
91-
project_field = bs_alt_name_dict.get(param, None)
92-
q = {
93-
param: bs_custom_field_tests[param],
94-
"chunk_size": 1,
95-
"num_chunks": 1,
96-
}
97-
doc = search_method(**q)[0].model_dump()
80+
with BandStructureRester() as bs_rester:
81+
# Query fields
82+
for param in bs_custom_field_tests:
83+
project_field = bs_alt_name_dict.get(param, None)
84+
q = {
85+
param: bs_custom_field_tests[param],
86+
"chunk_size": 1,
87+
"num_chunks": 1,
88+
}
89+
doc = bs_rester.search(**q)[0].model_dump()
9890

99-
for sub_field in bs_sub_doc_fields:
100-
if sub_field in doc:
101-
doc = doc[sub_field]
91+
for sub_field in bs_sub_doc_fields:
92+
if sub_field in doc:
93+
doc = doc[sub_field]
10294

103-
if param != "path_type":
104-
doc = doc["setyawan_curtarolo"]
95+
if param != "path_type":
96+
doc = doc["setyawan_curtarolo"]
10597

106-
assert doc[project_field if project_field is not None else param] is not None
98+
assert (
99+
doc[project_field if project_field is not None else param]
100+
is not None
101+
)
107102

108103

109-
dos_custom_field_tests = {
110-
"magnetic_ordering": Ordering.FM,
111-
"efermi": (0, 100),
112-
"band_gap": (0, 5),
113-
}
104+
dos_custom_field_tests = [
105+
{"magnetic_ordering": Ordering.FM},
106+
{"efermi": (1, 1.1)},
107+
{"band_gap": (8.0, 9.0)},
108+
{"projection_type": "elemental", "element": "As"},
109+
{
110+
"magnetic_ordering": "FM",
111+
},
112+
]
114113

115114
dos_excluded_params = ["orbital", "element"]
116115

@@ -119,35 +118,28 @@ def test_bs_client(bs_rester):
119118
dos_alt_name_dict = {} # type: dict
120119

121120

122-
@pytest.fixture
123-
def dos_rester():
124-
rester = DosRester()
125-
yield rester
126-
rester.session.close()
127-
128-
129121
@requires_api_key
130-
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
131-
def test_dos_client(dos_rester):
132-
search_method = dos_rester.search
133-
134-
# Query fields
135-
for param in dos_custom_field_tests:
136-
if param not in dos_excluded_params:
137-
project_field = dos_alt_name_dict.get(param, None)
122+
def test_dos_client():
123+
with DosRester() as dos_rester:
124+
# Query fields
125+
for params in dos_custom_field_tests:
126+
if any(param in dos_excluded_params for param in params):
127+
continue
138128
q = {
139-
param: dos_custom_field_tests[param],
129+
**params,
140130
"chunk_size": 1,
141131
"num_chunks": 1,
142132
}
143-
doc = search_method(**q)[0].model_dump()
133+
doc = dos_rester.search(**q)[0].model_dump()
144134
for sub_field in dos_sub_doc_fields:
145135
if sub_field in doc:
146136
doc = doc[sub_field]
147137

148-
if param != "projection_type" and param != "magnetic_ordering":
138+
if not any(
139+
param in params for param in {"projection_type", "magnetic_ordering"}
140+
):
149141
doc = doc["total"]["1"]
150142

151-
assert (
152-
doc[project_field if project_field is not None else param] is not None
143+
assert all(
144+
doc[dos_alt_name_dict.get(param, param)] is not None for param in params
153145
)

0 commit comments

Comments
 (0)