Skip to content

Commit 7d50936

Browse files
Support for querying symmetry information by lists (#1014)
1 parent 459266f commit 7d50936

File tree

2 files changed

+64
-12
lines changed

2 files changed

+64
-12
lines changed

mp_api/client/routes/materials/summary.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def search( # noqa: D417
5656
poisson_ratio: tuple[float, float] | None = None,
5757
possible_species: list[str] | None = None,
5858
shape_factor: tuple[float, float] | None = None,
59-
spacegroup_number: int | None = None,
60-
spacegroup_symbol: str | None = None,
59+
spacegroup_number: int | list[int] | None = None,
60+
spacegroup_symbol: str | list[str] | None = None,
6161
surface_energy_anisotropy: tuple[float, float] | None = None,
6262
theoretical: bool | None = None,
6363
total_energy: tuple[float, float] | None = None,
@@ -81,7 +81,7 @@ def search( # noqa: D417
8181
band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
8282
chemsys (str, List[str]): A chemical system or list of chemical systems
8383
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
84-
crystal_system (CrystalSystem): Crystal system of material.
84+
crystal_system (CrystalSystem or list[CrystalSystem]): Crystal system(s) of the materials.
8585
density (Tuple[float,float]): Minimum and maximum density to consider.
8686
deprecated (bool): Whether the material is tagged as deprecated.
8787
e_electronic (Tuple[float,float]): Minimum and maximum electronic dielectric constant to consider.
@@ -128,8 +128,8 @@ def search( # noqa: D417
128128
poisson_ratio (Tuple[float,float]): Minimum and maximum value to consider for Poisson's ratio.
129129
possible_species (List(str)): List of element symbols appended with oxidation states. (e.g. Cr2+,O2-)
130130
shape_factor (Tuple[float,float]): Minimum and maximum shape factor values to consider.
131-
spacegroup_number (int): Space group number of material.
132-
spacegroup_symbol (str): Space group symbol of the material in international short symbol notation.
131+
spacegroup_number (int or list[int]): Space group number(s) of materials.
132+
spacegroup_symbol (str or list[str]): Space group symbol(s) of the materials in international short symbol notation.
133133
surface_energy_anisotropy (Tuple[float,float]): Minimum and maximum surface energy anisotropy values
134134
to consider.
135135
theoretical: (bool): Whether the material is theoretical.
@@ -319,13 +319,25 @@ def _csrc(x):
319319
if possible_species is not None:
320320
query_params.update({"possible_species": ",".join(possible_species)})
321321

322-
query_params.update(
323-
{
324-
"crystal_system": crystal_system,
325-
"spacegroup_number": spacegroup_number,
326-
"spacegroup_symbol": spacegroup_symbol,
327-
}
328-
)
322+
symm_cardinality = {
323+
"crystal_system": 7,
324+
"spacegroup_number": 230,
325+
"spacegroup_symbol": 230,
326+
}
327+
for k, cardinality in symm_cardinality.items():
328+
if hasattr(symm_vals := locals().get(k), "__len__") and not isinstance(
329+
symm_vals, str
330+
):
331+
if len(symm_vals) < cardinality // 2:
332+
query_params.update({k: ",".join(str(v) for v in symm_vals)})
333+
else:
334+
raise ValueError(
335+
f"Querying `{k}` by a list of values is only "
336+
f"supported for up to {cardinality//2 - 1} values. "
337+
f"For your query, retrieve all data first and then filter on `{k}`."
338+
)
339+
else:
340+
query_params.update({k: symm_vals})
329341

330342
if is_stable is not None:
331343
query_params.update({"is_stable": is_stable})

tests/materials/test_summary.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,43 @@ def test_client():
7272
custom_field_tests=custom_field_tests,
7373
sub_doc_fields=[],
7474
)
75+
76+
77+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
78+
def test_list_like_input():
79+
search_method = SummaryRester().search
80+
81+
# These are specifically chosen for the low representation in MP
82+
# Specifically, these are the four least-represented space groups
83+
# with at least one member
84+
sparse_sgn = (93, 101, 172, 179, 211)
85+
docs_by_number = search_method(
86+
spacegroup_number=sparse_sgn, fields=["material_id", "symmetry"]
87+
)
88+
assert {doc.symmetry.number for doc in docs_by_number} == set(sparse_sgn)
89+
90+
sparse_symbols = {doc.symmetry.symbol for doc in docs_by_number}
91+
docs_by_symbol = search_method(
92+
spacegroup_symbol=sparse_symbols, fields=["material_id", "symmetry"]
93+
)
94+
assert {doc.symmetry.symbol for doc in docs_by_symbol} == sparse_symbols
95+
assert {doc.material_id for doc in docs_by_symbol} == {
96+
doc.material_id for doc in docs_by_number
97+
}
98+
99+
# also chosen for very low document count
100+
crys_sys = ["Hexagonal", "Cubic"]
101+
assert {
102+
doc.symmetry.crystal_system
103+
for doc in search_method(elements=["Ar"], crystal_system=crys_sys)
104+
} == set(crys_sys)
105+
106+
# should fail - we don't support querying by so many list values
107+
with pytest.raises(ValueError, match="retrieve all data first and then filter"):
108+
_ = search_method(spacegroup_number=list(range(1, 231)))
109+
110+
with pytest.raises(ValueError, match="retrieve all data first and then filter"):
111+
_ = search_method(spacegroup_number=["null" for _ in range(230)])
112+
113+
with pytest.raises(ValueError, match="retrieve all data first and then filter"):
114+
_ = search_method(crystal_system=list(CrystalSystem))

0 commit comments

Comments
 (0)