Skip to content

Commit 17d66f0

Browse files
Merge remote-tracking branch 'upstream/main' into mcp
2 parents 384c2a9 + 52a3c57 commit 17d66f0

13 files changed

+203
-247
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232

3333
- name: Install Python dependencies
3434
run: |
35-
python -m pip install --upgrade pip
35+
python -m pip install --upgrade "pip<25.3"
3636
pip install -r requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}.txt
3737
pip install -r requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt
3838

.github/workflows/upgrade_dependencies.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ jobs:
2323
- name: Upgrade Python dependencies
2424
shell: bash
2525
run: |
26-
python -m pip install --upgrade pip pip-tools
27-
python -m piptools compile -q --upgrade --resolver=backtracking -o requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}.txt pyproject.toml
28-
python -m piptools compile -q --upgrade --resolver=backtracking --all-extras -o requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt pyproject.toml
26+
python -m pip install --upgrade "pip<25.3" pip-tools
27+
python -m piptools compile -q --upgrade -o requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}.txt pyproject.toml
28+
python -m piptools compile -q --upgrade --all-extras -o requirements/requirements-${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt pyproject.toml
2929
- name: Detect changes
3030
id: changes
3131
shell: bash

mp_api/client/core/client.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
from importlib.metadata import PackageNotFoundError, version
1919
from json import JSONDecodeError
2020
from math import ceil
21-
from typing import (
22-
TYPE_CHECKING,
23-
ForwardRef,
24-
Optional,
25-
get_args,
26-
)
21+
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
2722
from urllib.parse import quote, urljoin
2823

2924
import requests
@@ -64,6 +59,23 @@
6459
SETTINGS = MAPIClientSettings() # type: ignore
6560

6661

62+
class _DictLikeAccess(BaseModel):
63+
"""Define a pydantic mix-in which permits dict-like access to model fields."""
64+
65+
def __getitem__(self, item: str) -> Any:
66+
"""Return `item` if a valid model field, otherwise raise an exception."""
67+
if item in self.__class__.model_fields:
68+
return getattr(self, item)
69+
raise AttributeError(f"{self.__class__.__name__} has no model field `{item}`.")
70+
71+
def get(self, item: str, default: Any = None) -> Any:
72+
"""Return a model field `item`, or `default` if it doesn't exist."""
73+
try:
74+
return self.__getitem__(item)
75+
except AttributeError:
76+
return default
77+
78+
6779
class BaseRester:
6880
"""Base client class with core stubs."""
6981

@@ -427,13 +439,9 @@ def _query_resource(
427439
if use_document_model is None:
428440
use_document_model = self.use_document_model
429441

430-
if timeout is None:
431-
timeout = self.timeout
442+
timeout = self.timeout if timeout is None else timeout
432443

433-
if criteria:
434-
criteria = {k: v for k, v in criteria.items() if v is not None}
435-
else:
436-
criteria = {}
444+
criteria = {k: v for k, v in (criteria or {}).items() if v is not None}
437445

438446
# Query s3 if no query is passed and all documents are asked for
439447
# TODO also skip fields set to same as their default
@@ -1080,6 +1088,7 @@ def _generate_returned_model(
10801088
# TODO fields_not_requested is not the same as unset_fields
10811089
# i.e. field could be requested but not available in the raw doc
10821090
fields_not_requested=(list[str], unset_fields),
1091+
__base__=_DictLikeAccess,
10831092
__doc__=".".join(
10841093
[
10851094
getattr(self.document_model, k, "")

mp_api/client/mprester.py

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import os
55
import warnings
6+
from collections import defaultdict
67
from functools import cache, lru_cache
78
from typing import TYPE_CHECKING
89

@@ -419,20 +420,15 @@ def get_task_ids_associated_with_material_id(
419420
if not tasks:
420421
return []
421422

422-
calculations = (
423-
tasks[0].calc_types # type: ignore
424-
if self.use_document_model
425-
else tasks[0]["calc_types"] # type: ignore
426-
)
423+
calculations = tasks[0]["calc_types"]
427424

428425
if calc_types:
429426
return [
430427
task
431428
for task, calc_type in calculations.items()
432429
if calc_type in calc_types
433430
]
434-
else:
435-
return list(calculations.keys())
431+
return list(calculations.keys())
436432

437433
def get_structure_by_material_id(
438434
self, material_id: str, final: bool = True, conventional_unit_cell: bool = False
@@ -534,11 +530,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
534530
List of BibTeX references ([str])
535531
"""
536532
docs = self.materials.provenance.search(material_ids=material_id)
537-
538-
if not docs:
539-
return []
540-
541-
return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore
533+
return docs[0]["references"] if docs else []
542534

543535
def get_material_ids(
544536
self,
@@ -553,17 +545,16 @@ def get_material_ids(
553545
Returns:
554546
List of all materials ids ([MPID])
555547
"""
548+
inp_k = "formula"
556549
if isinstance(chemsys_formula, list) or (
557550
isinstance(chemsys_formula, str) and "-" in chemsys_formula
558551
):
559-
input_params = {"chemsys": chemsys_formula}
560-
else:
561-
input_params = {"formula": chemsys_formula}
552+
inp_k = "chemsys"
562553

563554
return sorted(
564-
doc.material_id if self.use_document_model else doc["material_id"] # type: ignore
555+
doc["material_id"]
565556
for doc in self.materials.search(
566-
**input_params, # type: ignore
557+
**{inp_k: chemsys_formula},
567558
all_fields=False,
568559
fields=["material_id"],
569560
)
@@ -596,10 +587,8 @@ def get_structures(
596587
all_fields=False,
597588
fields=["structure"],
598589
)
599-
if not self.use_document_model:
600-
return [doc["structure"] for doc in docs] # type: ignore
601590

602-
return [doc.structure for doc in docs] # type: ignore
591+
return [doc["structure"] for doc in docs]
603592
else:
604593
structures = []
605594

@@ -608,12 +597,7 @@ def get_structures(
608597
all_fields=False,
609598
fields=["initial_structures"],
610599
):
611-
initial_structures = (
612-
doc.initial_structures # type: ignore
613-
if self.use_document_model
614-
else doc["initial_structures"] # type: ignore
615-
)
616-
structures.extend(initial_structures)
600+
structures.extend(doc["initial_structures"])
617601

618602
return structures
619603

@@ -718,7 +702,7 @@ def get_entries(
718702
if additional_criteria:
719703
input_params = {**input_params, **additional_criteria}
720704

721-
entries = []
705+
entries: set[ComputedStructureEntry] = set()
722706

723707
fields = (
724708
["entries", "thermo_type"]
@@ -733,24 +717,17 @@ def get_entries(
733717
)
734718

735719
for doc in docs:
736-
entry_list = (
737-
doc.entries.values() # type: ignore
738-
if self.use_document_model
739-
else doc["entries"].values() # type: ignore
740-
)
720+
entry_list = doc["entries"].values()
741721
for entry in entry_list:
742-
entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore
722+
entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore
743723
if not compatible_only:
744724
entry_dict["correction"] = 0.0
745725
entry_dict["energy_adjustments"] = []
746726

747727
if property_data:
748-
for property in property_data:
749-
entry_dict["data"][property] = (
750-
doc.model_dump()[property] # type: ignore
751-
if self.use_document_model
752-
else doc[property] # type: ignore
753-
)
728+
entry_dict["data"] = {
729+
property: doc[property] for property in property_data
730+
}
754731

755732
if conventional_unit_cell:
756733
entry_struct = Structure.from_dict(entry_dict["structure"])
@@ -771,15 +748,10 @@ def get_entries(
771748
if "n_atoms" in correction:
772749
correction["n_atoms"] *= site_ratio
773750

774-
entry = (
775-
ComputedStructureEntry.from_dict(entry_dict)
776-
if self.monty_decode
777-
else entry_dict
778-
)
751+
# Need to store object to permit de-duplication
752+
entries.add(ComputedStructureEntry.from_dict(entry_dict))
779753

780-
entries.append(entry)
781-
782-
return entries
754+
return [e if self.monty_decode else e.as_dict() for e in entries]
783755

784756
def get_pourbaix_entries(
785757
self,
@@ -1310,9 +1282,7 @@ def get_wulff_shape(self, material_id: str):
13101282
if not doc:
13111283
return None
13121284

1313-
surfaces: list = (
1314-
doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore
1315-
)
1285+
surfaces: list = doc[0]["surfaces"]
13161286

13171287
lattice = (
13181288
SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
@@ -1382,17 +1352,8 @@ def get_charge_density_from_material_id(
13821352
if len(results) == 0:
13831353
return None
13841354

1385-
latest_doc = max( # type: ignore
1386-
results,
1387-
key=lambda x: (
1388-
x.last_updated # type: ignore
1389-
if self.use_document_model
1390-
else x["last_updated"]
1391-
), # type: ignore
1392-
)
1393-
task_id = (
1394-
latest_doc.task_id if self.use_document_model else latest_doc["task_id"]
1395-
)
1355+
latest_doc = max(results, key=lambda x: x["last_updated"])
1356+
task_id = latest_doc["task_id"]
13961357
return self.get_charge_density_from_task_id(task_id, inc_task_doc)
13971358

13981359
def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
@@ -1414,20 +1375,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14141375
else []
14151376
)
14161377

1417-
meta = {}
1378+
meta = defaultdict(list)
14181379
for doc in self.materials.search( # type: ignore
14191380
task_ids=material_ids,
14201381
fields=["calc_types", "deprecated_tasks", "material_id"],
14211382
):
1422-
doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore
1423-
for task_id, calc_type in doc_dict["calc_types"].items():
1383+
for task_id, calc_type in doc["calc_types"].items():
14241384
if calc_types and calc_type not in calc_types:
14251385
continue
1426-
mp_id = doc_dict["material_id"]
1427-
if meta.get(mp_id) is None:
1428-
meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}]
1429-
else:
1430-
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1386+
mp_id = doc["material_id"]
1387+
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1388+
14311389
if not meta:
14321390
raise ValueError(f"No tasks found for material id {material_ids}.")
14331391

0 commit comments

Comments
 (0)