Skip to content

Commit 7de3a4b

Browse files
add phonon tests
1 parent 8a222a7 commit 7de3a4b

File tree

3 files changed

+109
-16
lines changed

3 files changed

+109
-16
lines changed

mp_api/client/routes/materials/phonon.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from typing import TYPE_CHECKING
45

56
import numpy as np
67
from emmet.core.phonon import PhononBS, PhononBSDOSDoc, PhononDOS
78

89
from mp_api.client.core import BaseRester, MPRestError
910
from mp_api.client.core.utils import validate_ids
1011

12+
if TYPE_CHECKING:
13+
from typing import Any
14+
from emmet.core.math import Matrix3D
15+
1116

1217
class PhononRester(BaseRester):
1318
suffix = "materials/phonon"
@@ -63,7 +68,7 @@ def search(
6368
**query_params,
6469
)
6570

66-
def get_bandstructure_from_material_id(self, material_id: str, phonon_method: str):
71+
def get_bandstructure_from_material_id(self, material_id: str, phonon_method: str) -> PhononBS | dict[str,Any]:
6772
"""Get the phonon band structure pymatgen object associated with a given material ID and phonon method.
6873
6974
Arguments:
@@ -73,10 +78,13 @@ def get_bandstructure_from_material_id(self, material_id: str, phonon_method: st
7378
Returns:
7479
bandstructure (PhononBS): PhononBS object
7580
"""
76-
result = self._query_open_data(
77-
bucket="materialsproject-parsed",
78-
key=f"ph-bandstructures/{phonon_method}/{material_id}.json.gz",
79-
)[0]
81+
try:
82+
result = self._query_open_data(
83+
bucket="materialsproject-parsed",
84+
key=f"ph-bandstructures/{phonon_method}/{material_id}.json.gz",
85+
)[0]
86+
except OSError:
87+
result = None
8088

8189
if not result or not result[0]:
8290
raise MPRestError("No object found")
@@ -86,7 +94,7 @@ def get_bandstructure_from_material_id(self, material_id: str, phonon_method: st
8694

8795
return result[0]
8896

89-
def get_dos_from_material_id(self, material_id: str, phonon_method: str):
97+
def get_dos_from_material_id(self, material_id: str, phonon_method: str) -> PhononDOS | dict[str,Any]:
9098
"""Get the phonon dos pymatgen object associated with a given material ID and phonon method.
9199
92100
Arguments:
@@ -96,10 +104,13 @@ def get_dos_from_material_id(self, material_id: str, phonon_method: str):
96104
Returns:
97105
dos (PhononDOS): PhononDOS object
98106
"""
99-
result = self._query_open_data(
100-
bucket="materialsproject-parsed",
101-
key=f"ph-dos/{phonon_method}/{material_id}.json.gz",
102-
)[0]
107+
try:
108+
result = self._query_open_data(
109+
bucket="materialsproject-parsed",
110+
key=f"ph-dos/{phonon_method}/{material_id}.json.gz",
111+
)[0]
112+
except OSError:
113+
result = None
103114

104115
if not result or not result[0]:
105116
raise MPRestError("No object found")
@@ -109,7 +120,7 @@ def get_dos_from_material_id(self, material_id: str, phonon_method: str):
109120

110121
return result[0]
111122

112-
def get_forceconstants_from_material_id(self, material_id: str):
123+
def get_forceconstants_from_material_id(self, material_id: str) -> list[list[Matrix3D]]:
113124
"""Get the force constants associated with a given material ID.
114125
115126
Arguments:
@@ -118,10 +129,13 @@ def get_forceconstants_from_material_id(self, material_id: str):
118129
Returns:
119130
force constants (list[list[Matrix3D]]): PhononDOS object
120131
"""
121-
result = self._query_open_data(
122-
bucket="materialsproject-parsed",
123-
key=f"ph-force-constants/{material_id}.json.gz",
124-
)[0]
132+
try:
133+
result = self._query_open_data(
134+
bucket="materialsproject-parsed",
135+
key=f"ph-force-constants/{material_id}.json.gz",
136+
)[0]
137+
except OSError:
138+
result = None
125139

126140
if not result or not result[0]:
127141
raise MPRestError("No object found")

tests/materials/test_phonon.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
2+
import os
3+
4+
import numpy as np
5+
import pytest
6+
7+
from emmet.core.phonon import PhononBS, PhononDOS
8+
9+
from mp_api.client.core import MPRestError
10+
from mp_api.client.routes.materials.phonon import PhononRester
11+
12+
from core_function import client_search_testing
13+
14+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
15+
def test_phonon_search():
16+
client_search_testing(
17+
search_method=PhononRester().search,
18+
excluded_params=[
19+
"num_chunks",
20+
"chunk_size",
21+
"all_fields",
22+
"fields",
23+
],
24+
alt_name_dict={
25+
"material_ids": "material_id",
26+
},
27+
custom_field_tests = {
28+
"material_ids": ["mp-149","mp-13"],
29+
"material_ids": "mp-149",
30+
"phonon_method": "dfpt",
31+
},
32+
sub_doc_fields=[],
33+
)
34+
35+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
36+
@pytest.mark.parametrize("use_document_model",[True,False])
37+
def test_phonon_get_methods(use_document_model):
38+
39+
rester = PhononRester(use_document_model=use_document_model)
40+
41+
# TODO: update when there is force constant data
42+
for func_name, schema in {
43+
"bandstructure" : PhononBS,
44+
"dos": PhononDOS,
45+
# "forceconstants": list
46+
}.items():
47+
search_method = getattr(
48+
rester,
49+
f"get_{func_name}_from_material_id",
50+
)
51+
assert isinstance(
52+
search_method("mp-149","dfpt"),
53+
schema if use_document_model else dict
54+
)
55+
56+
with pytest.raises(MPRestError,match="No object found"):
57+
_ = search_method("mp-0","dfpt")
58+
59+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
60+
@pytest.mark.parametrize("use_document_model",[True,False])
61+
def test_phonon_thermo(use_document_model):
62+
63+
with pytest.raises(MPRestError,match="No phonon document found"):
64+
_ = PhononRester(
65+
use_document_model=use_document_model
66+
).compute_thermo_quantities("mp-0","dfpt")
67+
68+
thermo_props = PhononRester(
69+
use_document_model=use_document_model
70+
).compute_thermo_quantities("mp-149","dfpt")
71+
72+
# Default set in the method
73+
num_vals = 100
74+
75+
assert all(
76+
isinstance(v, np.ndarray if k == "temperature" else list)
77+
and len(v) == num_vals
78+
for k, v in thermo_props.items()
79+
)

tests/materials/test_similarity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_struct():
2828
return Structure.from_str(poscar,fmt="poscar")
2929

3030
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
31-
def test_client():
31+
def test_similarity_search():
3232
client_search_testing(
3333
search_method=SimilarityRester().search,
3434
excluded_params=[

0 commit comments

Comments
 (0)