From aca3734a0c341c89d7fc02f4bf5bdfcbd47155b7 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 13 Oct 2025 16:14:56 -0700 Subject: [PATCH 1/2] add query by list of space group symbol, number, or crystal system + test --- mp_api/client/routes/materials/summary.py | 30 ++++++++++++++------- tests/materials/test_summary.py | 33 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 1aec0c75a..09c1e2766 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -56,8 +56,8 @@ def search( # noqa: D417 poisson_ratio: tuple[float, float] | None = None, possible_species: list[str] | None = None, shape_factor: tuple[float, float] | None = None, - spacegroup_number: int | None = None, - spacegroup_symbol: str | None = None, + spacegroup_number: int | list[int] | None = None, + spacegroup_symbol: str | list[str] | None = None, surface_energy_anisotropy: tuple[float, float] | None = None, theoretical: bool | None = None, total_energy: tuple[float, float] | None = None, @@ -319,13 +319,25 @@ def _csrc(x): if possible_species is not None: query_params.update({"possible_species": ",".join(possible_species)}) - query_params.update( - { - "crystal_system": crystal_system, - "spacegroup_number": spacegroup_number, - "spacegroup_symbol": spacegroup_symbol, - } - ) + symm_cardinality = { + "crystal_system": 7, + "spacegroup_number": 230, + "spacegroup_symbol": 230, + } + for k, cardinality in symm_cardinality.items(): + if hasattr(symm_vals := locals().get(k), "__len__") and not isinstance( + symm_vals, str + ): + if len(symm_vals) < cardinality // 2: + query_params.update({k: ",".join(str(v) for v in symm_vals)}) + else: + raise ValueError( + f"Querying `{k}` by a list of values is only " + f"supported for up to {cardinality//2 - 1} values. " + f"For your query, retrieve all data first and then filter on `{k}`." + ) + else: + query_params.update({k: symm_vals}) if is_stable is not None: query_params.update({"is_stable": is_stable}) diff --git a/tests/materials/test_summary.py b/tests/materials/test_summary.py index 36e1ac0d9..51a9f183d 100644 --- a/tests/materials/test_summary.py +++ b/tests/materials/test_summary.py @@ -72,3 +72,36 @@ def test_client(): custom_field_tests=custom_field_tests, sub_doc_fields=[], ) + + +@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.") +def test_list_like_input(): + search_method = SummaryRester().search + + # These are specifically chosen for the low representation in MP + # Specifically, these are the four least-represented space groups + # with at least one member + sparse_sgn = (93, 101, 172, 179, 211) + docs_by_number = search_method( + spacegroup_number=sparse_sgn, fields=["material_id", "symmetry"] + ) + assert {doc.symmetry.number for doc in docs_by_number} == set(sparse_sgn) + + sparse_symbols = {doc.symmetry.symbol for doc in docs_by_number} + docs_by_symbol = search_method( + spacegroup_symbol=sparse_symbols, fields=["material_id", "symmetry"] + ) + assert {doc.symmetry.symbol for doc in docs_by_symbol} == sparse_symbols + assert {doc.material_id for doc in docs_by_symbol} == { + doc.material_id for doc in docs_by_number + } + + # should fail - we don't support querying by so many list values + with pytest.raises(ValueError, match="retrieve all data first and then filter"): + _ = search_method(spacegroup_number=list(range(1, 231))) + + with pytest.raises(ValueError, match="retrieve all data first and then filter"): + _ = search_method(spacegroup_number=["null" for _ in range(230)]) + + with pytest.raises(ValueError, match="retrieve all data first and then filter"): + _ = search_method(crystal_system=list(CrystalSystem)) From c1b872e7684a39730c61c570771308a8f1c50972 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 13 Oct 2025 16:25:04 -0700 Subject: [PATCH 2/2] add test for crystal_system + docstr --- mp_api/client/routes/materials/summary.py | 6 +++--- tests/materials/test_summary.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 09c1e2766..e04d71f3e 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -81,7 +81,7 @@ def search( # noqa: D417 band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider. chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). - crystal_system (CrystalSystem): Crystal system of material. + crystal_system (CrystalSystem or list[CrystalSystem]): Crystal system(s) of the materials. density (Tuple[float,float]): Minimum and maximum density to consider. deprecated (bool): Whether the material is tagged as deprecated. e_electronic (Tuple[float,float]): Minimum and maximum electronic dielectric constant to consider. @@ -128,8 +128,8 @@ def search( # noqa: D417 poisson_ratio (Tuple[float,float]): Minimum and maximum value to consider for Poisson's ratio. possible_species (List(str)): List of element symbols appended with oxidation states. (e.g. Cr2+,O2-) shape_factor (Tuple[float,float]): Minimum and maximum shape factor values to consider. - spacegroup_number (int): Space group number of material. - spacegroup_symbol (str): Space group symbol of the material in international short symbol notation. + spacegroup_number (int or list[int]): Space group number(s) of materials. + spacegroup_symbol (str or list[str]): Space group symbol(s) of the materials in international short symbol notation. surface_energy_anisotropy (Tuple[float,float]): Minimum and maximum surface energy anisotropy values to consider. theoretical: (bool): Whether the material is theoretical. diff --git a/tests/materials/test_summary.py b/tests/materials/test_summary.py index 51a9f183d..c8f86bef0 100644 --- a/tests/materials/test_summary.py +++ b/tests/materials/test_summary.py @@ -96,6 +96,13 @@ def test_list_like_input(): doc.material_id for doc in docs_by_number } + # also chosen for very low document count + crys_sys = ["Hexagonal", "Cubic"] + assert { + doc.symmetry.crystal_system + for doc in search_method(elements=["Ar"], crystal_system=crys_sys) + } == set(crys_sys) + # should fail - we don't support querying by so many list values with pytest.raises(ValueError, match="retrieve all data first and then filter"): _ = search_method(spacegroup_number=list(range(1, 231)))