Skip to content

Commit 778e245

Browse files
precommit
1 parent 7de3a4b commit 778e245

File tree

11 files changed

+111
-105
lines changed

11 files changed

+111
-105
lines changed

mp_api/client/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import re
43
from typing import TYPE_CHECKING, Literal
54

65
import orjson
@@ -12,7 +11,7 @@
1211
from mp_api.client.core.settings import MAPIClientSettings
1312

1413
if TYPE_CHECKING:
15-
from monty.json import MSONable
14+
pass
1615

1716

1817
def _compare_emmet_ver(
@@ -40,6 +39,7 @@ def _compare_emmet_ver(
4039
f"__{op_to_op.get(op,op)}__",
4140
)(parse_version(ref_version))
4241

42+
4343
def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"):
4444
"""Utility to load json in consistent manner."""
4545
data = orjson.loads(

mp_api/client/routes/materials/phonon.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
if TYPE_CHECKING:
1313
from typing import Any
14+
1415
from emmet.core.math import Matrix3D
1516

1617

@@ -68,7 +69,9 @@ def search(
6869
**query_params,
6970
)
7071

71-
def get_bandstructure_from_material_id(self, material_id: str, phonon_method: str) -> PhononBS | dict[str,Any]:
72+
def get_bandstructure_from_material_id(
73+
self, material_id: str, phonon_method: str
74+
) -> PhononBS | dict[str, Any]:
7275
"""Get the phonon band structure pymatgen object associated with a given material ID and phonon method.
7376
7477
Arguments:
@@ -94,7 +97,9 @@ def get_bandstructure_from_material_id(self, material_id: str, phonon_method: st
9497

9598
return result[0]
9699

97-
def get_dos_from_material_id(self, material_id: str, phonon_method: str) -> PhononDOS | dict[str,Any]:
100+
def get_dos_from_material_id(
101+
self, material_id: str, phonon_method: str
102+
) -> PhononDOS | dict[str, Any]:
98103
"""Get the phonon dos pymatgen object associated with a given material ID and phonon method.
99104
100105
Arguments:
@@ -120,7 +125,9 @@ def get_dos_from_material_id(self, material_id: str, phonon_method: str) -> Phon
120125

121126
return result[0]
122127

123-
def get_forceconstants_from_material_id(self, material_id: str) -> list[list[Matrix3D]]:
128+
def get_forceconstants_from_material_id(
129+
self, material_id: str
130+
) -> list[list[Matrix3D]]:
124131
"""Get the force constants associated with a given material ID.
125132
126133
Arguments:

mp_api/client/routes/materials/xas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mp_api.client.core.utils import validate_ids
1010

1111
if TYPE_CHECKING:
12-
from emmet.core.types.enums import XasEdge, XasType
12+
from emmet.core.types.enums import XasEdge, XasType
1313

1414

1515
class XASRester(BaseRester):

tests/core/test_oxygen_evolution.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from pymatgen.core import Composition
1010
from pymatgen.entries.computed_entries import ComputedEntry
1111

12-
from mp_api.client.core._oxygen_evolution import DEFAULT_CACHE_FILE, NIST_JANAF_O2_MU_T, OxygenEvolution
12+
from mp_api.client.core._oxygen_evolution import (
13+
DEFAULT_CACHE_FILE,
14+
NIST_JANAF_O2_MU_T,
15+
OxygenEvolution,
16+
)
1317

1418

1519
def test_interp():
@@ -38,19 +42,16 @@ def test_interp():
3842
with pytest.warns(UserWarning, match="outside the fitting range"):
3943
oxyevo.mu_to_temp_spline(badval)
4044

45+
4146
def test_get():
4247
"""Test data retrieval from NIST."""
4348

4449
data = {}
4550
data["temperature"], data["mu-mu_0K"] = OxygenEvolution().get_chempot_temp_data()
4651
assert DEFAULT_CACHE_FILE.exists()
47-
assert all(
48-
np.allclose(data[k],v) for k, v in NIST_JANAF_O2_MU_T.items()
49-
)
52+
assert all(np.allclose(data[k], v) for k, v in NIST_JANAF_O2_MU_T.items())
5053
json_data = json.loads(DEFAULT_CACHE_FILE.read_text())
51-
assert all(
52-
np.allclose(json_data[k],v) for k, v in NIST_JANAF_O2_MU_T.items()
53-
)
54+
assert all(np.allclose(json_data[k], v) for k, v in NIST_JANAF_O2_MU_T.items())
5455

5556

5657
def test_oxy_evo():

tests/core/test_utils.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,43 @@
66
from emmet.core.mpid import MPID, AlphaID
77

88

9-
def test_emmet_core_version_checks(monkeypatch : pytest.MonkeyPatch):
10-
11-
ref_ver = (1,2,"3rc5")
9+
def test_emmet_core_version_checks(monkeypatch: pytest.MonkeyPatch):
10+
ref_ver = (1, 2, "3rc5")
1211
ref_ver_str = ".".join(str(x) for x in ref_ver)
1312

1413
import emmet.core
15-
monkeypatch.setattr(emmet.core,"__version__",ref_ver_str)
14+
15+
monkeypatch.setattr(emmet.core, "__version__", ref_ver_str)
1616
from mp_api.client.core.utils import _compare_emmet_ver
1717

18-
assert _compare_emmet_ver(ref_ver_str,"==")
18+
assert _compare_emmet_ver(ref_ver_str, "==")
1919

20-
next_ver = ".".join(str(x) for x in [ref_ver[0] + 1,*ref_ver[1:]])
21-
assert _compare_emmet_ver(next_ver,"<")
22-
assert _compare_emmet_ver(next_ver,"<=")
20+
next_ver = ".".join(str(x) for x in [ref_ver[0] + 1, *ref_ver[1:]])
21+
assert _compare_emmet_ver(next_ver, "<")
22+
assert _compare_emmet_ver(next_ver, "<=")
2323

24-
prior_ver = ".".join(str(x) for x in [ref_ver[0],ref_ver[1]-1,*ref_ver[2:]])
25-
assert _compare_emmet_ver(prior_ver,">")
26-
assert _compare_emmet_ver(prior_ver,">=")
24+
prior_ver = ".".join(str(x) for x in [ref_ver[0], ref_ver[1] - 1, *ref_ver[2:]])
25+
assert _compare_emmet_ver(prior_ver, ">")
26+
assert _compare_emmet_ver(prior_ver, ">=")
2727

28-
def test_id_validation():
2928

29+
def test_id_validation():
3030
from mp_api.client.core.utils import validate_ids
3131
from mp_api.client.core.settings import MAPIClientSettings
3232

3333
max_num_idxs = MAPIClientSettings().MAX_LIST_LENGTH
3434

35-
with pytest.raises(ValueError,match="too long"):
36-
_ = validate_ids(
37-
[f"mp-{x}" for x in range(max_num_idxs + 1)]
38-
)
35+
with pytest.raises(ValueError, match="too long"):
36+
_ = validate_ids([f"mp-{x}" for x in range(max_num_idxs + 1)])
3937

4038
# For all legacy MPIDs, ensure these validate correctly
4139
assert all(
42-
isinstance(x,str) and MPID(x).string == x
43-
for x in validate_ids(
44-
[f"mp-{y}" for y in range(max_num_idxs)]
45-
)
40+
isinstance(x, str) and MPID(x).string == x
41+
for x in validate_ids([f"mp-{y}" for y in range(max_num_idxs)])
4642
)
4743

4844
# For all new AlphaIDs, ensure these validate correctly
4945
assert all(
50-
isinstance(x,str) and AlphaID(x).string == x
51-
for x in validate_ids(
52-
[y + AlphaID._cut_point for y in range(max_num_idxs)]
53-
)
54-
)
46+
isinstance(x, str) and AlphaID(x).string == x
47+
for x in validate_ids([y + AlphaID._cut_point for y in range(max_num_idxs)])
48+
)

tests/materials/test_phonon.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import os
32

43
import numpy as np
@@ -11,6 +10,7 @@
1110

1211
from core_function import client_search_testing
1312

13+
1414
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
1515
def test_phonon_search():
1616
client_search_testing(
@@ -24,23 +24,23 @@ def test_phonon_search():
2424
alt_name_dict={
2525
"material_ids": "material_id",
2626
},
27-
custom_field_tests = {
28-
"material_ids": ["mp-149","mp-13"],
27+
custom_field_tests={
28+
"material_ids": ["mp-149", "mp-13"],
2929
"material_ids": "mp-149",
3030
"phonon_method": "dfpt",
3131
},
3232
sub_doc_fields=[],
3333
)
3434

35+
3536
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
36-
@pytest.mark.parametrize("use_document_model",[True,False])
37+
@pytest.mark.parametrize("use_document_model", [True, False])
3738
def test_phonon_get_methods(use_document_model):
38-
3939
rester = PhononRester(use_document_model=use_document_model)
4040

4141
# TODO: update when there is force constant data
4242
for func_name, schema in {
43-
"bandstructure" : PhononBS,
43+
"bandstructure": PhononBS,
4444
"dos": PhononDOS,
4545
# "forceconstants": list
4646
}.items():
@@ -49,31 +49,29 @@ def test_phonon_get_methods(use_document_model):
4949
f"get_{func_name}_from_material_id",
5050
)
5151
assert isinstance(
52-
search_method("mp-149","dfpt"),
53-
schema if use_document_model else dict
52+
search_method("mp-149", "dfpt"), schema if use_document_model else dict
5453
)
5554

56-
with pytest.raises(MPRestError,match="No object found"):
57-
_ = search_method("mp-0","dfpt")
55+
with pytest.raises(MPRestError, match="No object found"):
56+
_ = search_method("mp-0", "dfpt")
57+
5858

5959
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
60-
@pytest.mark.parametrize("use_document_model",[True,False])
60+
@pytest.mark.parametrize("use_document_model", [True, False])
6161
def test_phonon_thermo(use_document_model):
62-
63-
with pytest.raises(MPRestError,match="No phonon document found"):
62+
with pytest.raises(MPRestError, match="No phonon document found"):
6463
_ = PhononRester(
6564
use_document_model=use_document_model
66-
).compute_thermo_quantities("mp-0","dfpt")
65+
).compute_thermo_quantities("mp-0", "dfpt")
6766

6867
thermo_props = PhononRester(
6968
use_document_model=use_document_model
70-
).compute_thermo_quantities("mp-149","dfpt")
71-
69+
).compute_thermo_quantities("mp-149", "dfpt")
70+
7271
# Default set in the method
7372
num_vals = 100
74-
73+
7574
assert all(
76-
isinstance(v, np.ndarray if k == "temperature" else list)
77-
and len(v) == num_vals
75+
isinstance(v, np.ndarray if k == "temperature" else list) and len(v) == num_vals
7876
for k, v in thermo_props.items()
79-
)
77+
)

tests/materials/test_similarity.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import os
32

43
import numpy as np
@@ -12,6 +11,7 @@
1211

1312
from core_function import client_search_testing
1413

14+
1515
@pytest.fixture(scope="module")
1616
def test_struct():
1717
poscar = """Al2
@@ -25,7 +25,8 @@ def test_struct():
2525
0.875 0.875 0.875 Al
2626
0.125 0.125 0.125 Al
2727
"""
28-
return Structure.from_str(poscar,fmt="poscar")
28+
return Structure.from_str(poscar, fmt="poscar")
29+
2930

3031
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
3132
def test_similarity_search():
@@ -40,45 +41,48 @@ def test_similarity_search():
4041
alt_name_dict={
4142
"material_ids": "material_id",
4243
},
43-
custom_field_tests = {
44-
"material_ids": ["mp-149","mp-13"],
45-
"material_ids": "mp-149"
44+
custom_field_tests={
45+
"material_ids": ["mp-149", "mp-13"],
46+
"material_ids": "mp-149",
4647
},
4748
sub_doc_fields=[],
4849
)
4950

51+
5052
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
5153
def test_similarity_vector_search(test_struct):
52-
5354
rester = SimilarityRester()
5455
fv = rester.fingerprint_structure(test_struct)
55-
assert isinstance(fv,np.ndarray)
56+
assert isinstance(fv, np.ndarray)
5657
assert len(fv) == 122
57-
assert isinstance(rester._fingerprinter,SimilarityScorer)
58-
58+
assert isinstance(rester._fingerprinter, SimilarityScorer)
5959

6060
get_top = 5
61-
sim_entries = rester.find_similar("mp-149",top=get_top)
62-
assert all(
63-
isinstance(entry,SimilarityEntry) for entry in sim_entries
64-
)
61+
sim_entries = rester.find_similar("mp-149", top=get_top)
62+
assert all(isinstance(entry, SimilarityEntry) for entry in sim_entries)
6563
assert len(sim_entries) == get_top
6664

67-
sim_dict_entries = SimilarityRester(use_document_model=False).find_similar("mp-149",top=get_top)
65+
sim_dict_entries = SimilarityRester(use_document_model=False).find_similar(
66+
"mp-149", top=get_top
67+
)
6868
assert all(
69-
isinstance(entry,dict) and all(
70-
k in entry for k in SimilarityEntry.model_fields
71-
)
69+
isinstance(entry, dict)
70+
and all(k in entry for k in SimilarityEntry.model_fields)
7271
for entry in sim_dict_entries
7372
)
7473

75-
with pytest.raises(MPRestError,match="No similarity data available for"):
74+
with pytest.raises(MPRestError, match="No similarity data available for"):
7675
_ = rester.find_similar("mp-0")
7776

7877
assert all(
79-
isinstance(entry,SimilarityEntry) and isinstance(entry.dissimilarity,float)
80-
for entry in rester.find_similar(test_struct, top = 2,)
78+
isinstance(entry, SimilarityEntry) and isinstance(entry.dissimilarity, float)
79+
for entry in rester.find_similar(
80+
test_struct,
81+
top=2,
82+
)
8183
)
8284

83-
with pytest.raises(MPRestError,match="Please submit a pymatgen Structure or MP ID"):
84-
_ = rester.find_similar(fv)
85+
with pytest.raises(
86+
MPRestError, match="Please submit a pymatgen Structure or MP ID"
87+
):
88+
_ = rester.find_similar(fv)

0 commit comments

Comments
 (0)