Skip to content

Commit 33b787f

Browse files
authored
Merge branch 'main' into deltalake
2 parents 7195adf + 5ecebec commit 33b787f

File tree

11 files changed

+154
-196
lines changed

11 files changed

+154
-196
lines changed

mp_api/client/core/client.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@
6363
SETTINGS = MAPIClientSettings() # type: ignore
6464

6565

66+
class _MissingFieldError(AttributeError, KeyError):
    """Raised when dict-style access names something that is not a model field.

    Derives from both `AttributeError` (the type this lookup has always
    raised) and `KeyError` (the type the mapping protocol prescribes for a
    failed `__getitem__`), so either `except` clause catches it.
    """

    def __str__(self) -> str:
        """Return the plain message text."""
        # `KeyError.__str__` renders the repr of a single argument (adding
        # quotes); pin the generic exception formatting so the message reads
        # the same as it did when a bare AttributeError was raised.
        return BaseException.__str__(self)


class _DictLikeAccess(BaseModel):
    """Define a pydantic mix-in which permits dict-like access to model fields."""

    def __getitem__(self, item: str) -> Any:
        """Return the value of model field `item`.

        Args:
            item: Name of the field to look up.

        Returns:
            The attribute stored on this instance under that field name.

        Raises:
            _MissingFieldError: If `item` is not one of the class's
                `model_fields`. The error is both a `KeyError` and an
                `AttributeError`.
        """
        if item not in self.__class__.model_fields:
            raise _MissingFieldError(
                f"{self.__class__.__name__} has no model field `{item}`."
            )
        return getattr(self, item)

    def get(self, item: str, default: Any = None) -> Any:
        """Return a model field `item`, or `default` if it doesn't exist.

        Args:
            item: Name of the field to look up.
            default: Value returned when the lookup fails.

        Returns:
            The field value, or `default` when the lookup raises
            `AttributeError`.
        """
        try:
            return self[item]
        except AttributeError:
            # Deliberately the broad AttributeError rather than only
            # _MissingFieldError: this keeps the original behaviour of also
            # falling back to `default` if `getattr` itself fails.
            return default
81+
82+
6683
class BaseRester:
6784
"""Base client class with core stubs."""
6885

@@ -431,13 +448,9 @@ def _query_resource(
431448
if use_document_model is None:
432449
use_document_model = self.use_document_model
433450

434-
if timeout is None:
435-
timeout = self.timeout
451+
timeout = self.timeout if timeout is None else timeout
436452

437-
if criteria:
438-
criteria = {k: v for k, v in criteria.items() if v is not None}
439-
else:
440-
criteria = {}
453+
criteria = {k: v for k, v in (criteria or {}).items() if v is not None}
441454

442455
# Query s3 if no query is passed and all documents are asked for
443456
# TODO also skip fields set to same as their default
@@ -1238,6 +1251,7 @@ def _generate_returned_model(
12381251
# TODO fields_not_requested is not the same as unset_fields
12391252
# i.e. field could be requested but not available in the raw doc
12401253
fields_not_requested=(list[str], unset_fields),
1254+
__base__=_DictLikeAccess,
12411255
__doc__=".".join(
12421256
[
12431257
getattr(self.document_model, k, "")

mp_api/client/mprester.py

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import os
55
import warnings
6+
from collections import defaultdict
67
from functools import cache, lru_cache
78
from typing import TYPE_CHECKING
89

@@ -437,20 +438,15 @@ def get_task_ids_associated_with_material_id(
437438
if not tasks:
438439
return []
439440

440-
calculations = (
441-
tasks[0].calc_types # type: ignore
442-
if self.use_document_model
443-
else tasks[0]["calc_types"] # type: ignore
444-
)
441+
calculations = tasks[0]["calc_types"]
445442

446443
if calc_types:
447444
return [
448445
task
449446
for task, calc_type in calculations.items()
450447
if calc_type in calc_types
451448
]
452-
else:
453-
return list(calculations.keys())
449+
return list(calculations.keys())
454450

455451
def get_structure_by_material_id(
456452
self, material_id: str, final: bool = True, conventional_unit_cell: bool = False
@@ -552,11 +548,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
552548
List of BibTeX references ([str])
553549
"""
554550
docs = self.materials.provenance.search(material_ids=material_id)
555-
556-
if not docs:
557-
return []
558-
559-
return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore
551+
return docs[0]["references"] if docs else []
560552

561553
def get_material_ids(
562554
self,
@@ -571,17 +563,16 @@ def get_material_ids(
571563
Returns:
572564
List of all materials ids ([MPID])
573565
"""
566+
inp_k = "formula"
574567
if isinstance(chemsys_formula, list) or (
575568
isinstance(chemsys_formula, str) and "-" in chemsys_formula
576569
):
577-
input_params = {"chemsys": chemsys_formula}
578-
else:
579-
input_params = {"formula": chemsys_formula}
570+
inp_k = "chemsys"
580571

581572
return sorted(
582-
doc.material_id if self.use_document_model else doc["material_id"] # type: ignore
573+
doc["material_id"]
583574
for doc in self.materials.search(
584-
**input_params, # type: ignore
575+
**{inp_k: chemsys_formula},
585576
all_fields=False,
586577
fields=["material_id"],
587578
)
@@ -614,10 +605,8 @@ def get_structures(
614605
all_fields=False,
615606
fields=["structure"],
616607
)
617-
if not self.use_document_model:
618-
return [doc["structure"] for doc in docs] # type: ignore
619608

620-
return [doc.structure for doc in docs] # type: ignore
609+
return [doc["structure"] for doc in docs]
621610
else:
622611
structures = []
623612

@@ -626,12 +615,7 @@ def get_structures(
626615
all_fields=False,
627616
fields=["initial_structures"],
628617
):
629-
initial_structures = (
630-
doc.initial_structures # type: ignore
631-
if self.use_document_model
632-
else doc["initial_structures"] # type: ignore
633-
)
634-
structures.extend(initial_structures)
618+
structures.extend(doc["initial_structures"])
635619

636620
return structures
637621

@@ -736,7 +720,7 @@ def get_entries(
736720
if additional_criteria:
737721
input_params = {**input_params, **additional_criteria}
738722

739-
entries = []
723+
entries: set[ComputedStructureEntry] = set()
740724

741725
fields = (
742726
["entries", "thermo_type"]
@@ -751,24 +735,17 @@ def get_entries(
751735
)
752736

753737
for doc in docs:
754-
entry_list = (
755-
doc.entries.values() # type: ignore
756-
if self.use_document_model
757-
else doc["entries"].values() # type: ignore
758-
)
738+
entry_list = doc["entries"].values()
759739
for entry in entry_list:
760-
entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore
740+
entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore
761741
if not compatible_only:
762742
entry_dict["correction"] = 0.0
763743
entry_dict["energy_adjustments"] = []
764744

765745
if property_data:
766-
for property in property_data:
767-
entry_dict["data"][property] = (
768-
doc.model_dump()[property] # type: ignore
769-
if self.use_document_model
770-
else doc[property] # type: ignore
771-
)
746+
entry_dict["data"] = {
747+
property: doc[property] for property in property_data
748+
}
772749

773750
if conventional_unit_cell:
774751
entry_struct = Structure.from_dict(entry_dict["structure"])
@@ -789,15 +766,10 @@ def get_entries(
789766
if "n_atoms" in correction:
790767
correction["n_atoms"] *= site_ratio
791768

792-
entry = (
793-
ComputedStructureEntry.from_dict(entry_dict)
794-
if self.monty_decode
795-
else entry_dict
796-
)
769+
# Need to store object to permit de-duplication
770+
entries.add(ComputedStructureEntry.from_dict(entry_dict))
797771

798-
entries.append(entry)
799-
800-
return entries
772+
return [e if self.monty_decode else e.as_dict() for e in entries]
801773

802774
def get_pourbaix_entries(
803775
self,
@@ -1328,9 +1300,7 @@ def get_wulff_shape(self, material_id: str):
13281300
if not doc:
13291301
return None
13301302

1331-
surfaces: list = (
1332-
doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore
1333-
)
1303+
surfaces: list = doc[0]["surfaces"]
13341304

13351305
lattice = (
13361306
SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
@@ -1400,17 +1370,8 @@ def get_charge_density_from_material_id(
14001370
if len(results) == 0:
14011371
return None
14021372

1403-
latest_doc = max( # type: ignore
1404-
results,
1405-
key=lambda x: (
1406-
x.last_updated # type: ignore
1407-
if self.use_document_model
1408-
else x["last_updated"]
1409-
), # type: ignore
1410-
)
1411-
task_id = (
1412-
latest_doc.task_id if self.use_document_model else latest_doc["task_id"]
1413-
)
1373+
latest_doc = max(results, key=lambda x: x["last_updated"])
1374+
task_id = latest_doc["task_id"]
14141375
return self.get_charge_density_from_task_id(task_id, inc_task_doc)
14151376

14161377
def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
@@ -1432,20 +1393,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14321393
else []
14331394
)
14341395

1435-
meta = {}
1396+
meta = defaultdict(list)
14361397
for doc in self.materials.search( # type: ignore
14371398
task_ids=material_ids,
14381399
fields=["calc_types", "deprecated_tasks", "material_id"],
14391400
):
1440-
doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore
1441-
for task_id, calc_type in doc_dict["calc_types"].items():
1401+
for task_id, calc_type in doc["calc_types"].items():
14421402
if calc_types and calc_type not in calc_types:
14431403
continue
1444-
mp_id = doc_dict["material_id"]
1445-
if meta.get(mp_id) is None:
1446-
meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}]
1447-
else:
1448-
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1404+
mp_id = doc["material_id"]
1405+
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1406+
14491407
if not meta:
14501408
raise ValueError(f"No tasks found for material id {material_ids}.")
14511409

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -276,61 +276,47 @@ def get_bandstructure_from_material_id(
276276
if not bs_doc:
277277
raise MPRestError("No electronic structure data found.")
278278

279-
bs_data = (
280-
bs_doc[0].bandstructure # type: ignore
281-
if self.use_document_model
282-
else bs_doc[0]["bandstructure"] # type: ignore
283-
)
284-
285-
if bs_data is None:
279+
if (bs_data := bs_doc[0]["bandstructure"]) is None:
286280
raise MPRestError(
287281
f"No {path_type.value} band structure data found for {material_id}"
288282
)
289-
else:
290-
bs_data: dict = (
291-
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
292-
)
293283

294-
if bs_data.get(path_type.value, None):
295-
bs_task_id = bs_data[path_type.value]["task_id"]
296-
else:
284+
bs_data: dict = (
285+
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
286+
)
287+
288+
if bs_data.get(path_type.value, None) is None:
297289
raise MPRestError(
298290
f"No {path_type.value} band structure data found for {material_id}"
299291
)
300-
else:
301-
bs_doc = es_rester.search(material_ids=material_id, fields=["dos"])
292+
bs_task_id = bs_data[path_type.value]["task_id"]
302293

303-
if not bs_doc:
294+
else:
295+
if not (
296+
bs_doc := es_rester.search(material_ids=material_id, fields=["dos"])
297+
):
304298
raise MPRestError("No electronic structure data found.")
305299

306-
bs_data = (
307-
bs_doc[0].dos # type: ignore
308-
if self.use_document_model
309-
else bs_doc[0]["dos"] # type: ignore
310-
)
311-
312-
if bs_data is None:
300+
if (bs_data := bs_doc[0]["dos"]) is None:
313301
raise MPRestError(
314302
f"No uniform band structure data found for {material_id}"
315303
)
316-
else:
317-
bs_data: dict = (
318-
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
319-
)
320304

321-
if bs_data.get("total", None):
322-
bs_task_id = bs_data["total"]["1"]["task_id"]
323-
else:
305+
bs_data: dict = (
306+
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
307+
)
308+
309+
if bs_data.get("total", None) is None:
324310
raise MPRestError(
325311
f"No uniform band structure data found for {material_id}"
326312
)
313+
bs_task_id = bs_data["total"]["1"]["task_id"]
327314

328315
bs_obj = self.get_bandstructure_from_task_id(bs_task_id)
329316

330317
if bs_obj:
331318
return bs_obj
332-
else:
333-
raise MPRestError("No band structure object found.")
319+
raise MPRestError("No band structure object found.")
334320

335321

336322
class DosRester(BaseRester):
@@ -456,22 +442,16 @@ def get_dos_from_material_id(self, material_id: str):
456442
mute_progress_bars=self.mute_progress_bars,
457443
)
458444

459-
dos_doc = es_rester.search(material_ids=material_id, fields=["dos"])
460-
if not dos_doc:
445+
if not (dos_doc := es_rester.search(material_ids=material_id, fields=["dos"])):
461446
return None
462447

463-
dos_data: dict = (
464-
dos_doc[0].model_dump() if self.use_document_model else dos_doc[0] # type: ignore
465-
)
466-
467-
if dos_data["dos"]:
468-
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
469-
else:
448+
if not (dos_data := dos_doc[0].get("dos")):
470449
raise MPRestError(f"No density of states data found for {material_id}")
471450

472-
dos_obj = self.get_dos_from_task_id(dos_task_id)
473-
474-
if dos_obj:
451+
dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[
452+
"total"
453+
]["1"]["task_id"]
454+
if dos_obj := self.get_dos_from_task_id(dos_task_id):
475455
return dos_obj
476-
else:
477-
raise MPRestError("No density of states object found.")
456+
457+
raise MPRestError("No density of states object found.")

mp_api/client/routes/materials/materials.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,17 @@ def get_structure_by_material_id(
126126

127127
response = self.search(material_ids=material_id, fields=[field])
128128

129-
if response:
130-
response = (
131-
response[0].model_dump() if self.use_document_model else response[0] # type: ignore
132-
)
129+
if response and response[0]:
130+
response = response[0]
131+
# Ensure that return type is a Structure regardless of `monty_decode` or `model_dump` output
132+
if isinstance(response[field], dict):
133+
response[field] = Structure.from_dict(response[field])
134+
elif isinstance(response[field], list) and any(
135+
isinstance(struct, dict) for struct in response[field]
136+
):
137+
response[field] = [
138+
Structure.from_dict(struct) for struct in response[field]
139+
]
133140

134141
return response[field] if response else response # type: ignore
135142

@@ -305,7 +312,4 @@ def find_structure(
305312
)
306313
return results # type: ignore
307314

308-
if results:
309-
return results[0]["material_id"]
310-
else:
311-
return []
315+
return results[0]["material_id"] if (results and results[0]) else []

0 commit comments

Comments
 (0)