33import itertools
44import os
55import warnings
6+ from collections import defaultdict
67from functools import cache , lru_cache
78from typing import TYPE_CHECKING
89
@@ -437,20 +438,15 @@ def get_task_ids_associated_with_material_id(
437438 if not tasks :
438439 return []
439440
440- calculations = (
441- tasks [0 ].calc_types # type: ignore
442- if self .use_document_model
443- else tasks [0 ]["calc_types" ] # type: ignore
444- )
441+ calculations = tasks [0 ]["calc_types" ]
445442
446443 if calc_types :
447444 return [
448445 task
449446 for task , calc_type in calculations .items ()
450447 if calc_type in calc_types
451448 ]
452- else :
453- return list (calculations .keys ())
449+ return list (calculations .keys ())
454450
455451 def get_structure_by_material_id (
456452 self , material_id : str , final : bool = True , conventional_unit_cell : bool = False
@@ -552,11 +548,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
552548 List of BibTeX references ([str])
553549 """
554550 docs = self .materials .provenance .search (material_ids = material_id )
555-
556- if not docs :
557- return []
558-
559- return docs [0 ].references if self .use_document_model else docs [0 ]["references" ] # type: ignore
551+ return docs [0 ]["references" ] if docs else []
560552
561553 def get_material_ids (
562554 self ,
@@ -571,17 +563,16 @@ def get_material_ids(
571563 Returns:
572564 List of all materials ids ([MPID])
573565 """
566+ inp_k = "formula"
574567 if isinstance (chemsys_formula , list ) or (
575568 isinstance (chemsys_formula , str ) and "-" in chemsys_formula
576569 ):
577- input_params = {"chemsys" : chemsys_formula }
578- else :
579- input_params = {"formula" : chemsys_formula }
570+ inp_k = "chemsys"
580571
581572 return sorted (
582- doc . material_id if self . use_document_model else doc ["material_id" ] # type: ignore
573+ doc ["material_id" ]
583574 for doc in self .materials .search (
584- ** input_params , # type: ignore
575+ ** { inp_k : chemsys_formula },
585576 all_fields = False ,
586577 fields = ["material_id" ],
587578 )
@@ -614,10 +605,8 @@ def get_structures(
614605 all_fields = False ,
615606 fields = ["structure" ],
616607 )
617- if not self .use_document_model :
618- return [doc ["structure" ] for doc in docs ] # type: ignore
619608
620- return [doc . structure for doc in docs ] # type: ignore
609+ return [doc [ " structure" ] for doc in docs ]
621610 else :
622611 structures = []
623612
@@ -626,12 +615,7 @@ def get_structures(
626615 all_fields = False ,
627616 fields = ["initial_structures" ],
628617 ):
629- initial_structures = (
630- doc .initial_structures # type: ignore
631- if self .use_document_model
632- else doc ["initial_structures" ] # type: ignore
633- )
634- structures .extend (initial_structures )
618+ structures .extend (doc ["initial_structures" ])
635619
636620 return structures
637621
@@ -736,7 +720,7 @@ def get_entries(
736720 if additional_criteria :
737721 input_params = {** input_params , ** additional_criteria }
738722
739- entries = []
723+ entries : set [ ComputedStructureEntry ] = set ()
740724
741725 fields = (
742726 ["entries" , "thermo_type" ]
@@ -751,24 +735,17 @@ def get_entries(
751735 )
752736
753737 for doc in docs :
754- entry_list = (
755- doc .entries .values () # type: ignore
756- if self .use_document_model
757- else doc ["entries" ].values () # type: ignore
758- )
738+ entry_list = doc ["entries" ].values ()
759739 for entry in entry_list :
760- entry_dict : dict = entry .as_dict () if self . monty_decode else entry # type: ignore
740+ entry_dict : dict = entry .as_dict () if hasattr ( entry , "as_dict" ) else entry # type: ignore
761741 if not compatible_only :
762742 entry_dict ["correction" ] = 0.0
763743 entry_dict ["energy_adjustments" ] = []
764744
765745 if property_data :
766- for property in property_data :
767- entry_dict ["data" ][property ] = (
768- doc .model_dump ()[property ] # type: ignore
769- if self .use_document_model
770- else doc [property ] # type: ignore
771- )
746+ entry_dict ["data" ] = {
747+ property : doc [property ] for property in property_data
748+ }
772749
773750 if conventional_unit_cell :
774751 entry_struct = Structure .from_dict (entry_dict ["structure" ])
@@ -789,15 +766,10 @@ def get_entries(
789766 if "n_atoms" in correction :
790767 correction ["n_atoms" ] *= site_ratio
791768
792- entry = (
793- ComputedStructureEntry .from_dict (entry_dict )
794- if self .monty_decode
795- else entry_dict
796- )
769+ # Need to store object to permit de-duplication
770+ entries .add (ComputedStructureEntry .from_dict (entry_dict ))
797771
798- entries .append (entry )
799-
800- return entries
772+ return [e if self .monty_decode else e .as_dict () for e in entries ]
801773
802774 def get_pourbaix_entries (
803775 self ,
@@ -1328,9 +1300,7 @@ def get_wulff_shape(self, material_id: str):
13281300 if not doc :
13291301 return None
13301302
1331- surfaces : list = (
1332- doc [0 ].surfaces if self .use_document_model else doc [0 ]["surfaces" ] # type: ignore
1333- )
1303+ surfaces : list = doc [0 ]["surfaces" ]
13341304
13351305 lattice = (
13361306 SpacegroupAnalyzer (structure ).get_conventional_standard_structure ().lattice
@@ -1400,17 +1370,8 @@ def get_charge_density_from_material_id(
14001370 if len (results ) == 0 :
14011371 return None
14021372
1403- latest_doc = max ( # type: ignore
1404- results ,
1405- key = lambda x : (
1406- x .last_updated # type: ignore
1407- if self .use_document_model
1408- else x ["last_updated" ]
1409- ), # type: ignore
1410- )
1411- task_id = (
1412- latest_doc .task_id if self .use_document_model else latest_doc ["task_id" ]
1413- )
1373+ latest_doc = max (results , key = lambda x : x ["last_updated" ])
1374+ task_id = latest_doc ["task_id" ]
14141375 return self .get_charge_density_from_task_id (task_id , inc_task_doc )
14151376
14161377 def get_download_info (self , material_ids , calc_types = None , file_patterns = None ):
@@ -1432,20 +1393,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14321393 else []
14331394 )
14341395
1435- meta = {}
1396+ meta = defaultdict ( list )
14361397 for doc in self .materials .search ( # type: ignore
14371398 task_ids = material_ids ,
14381399 fields = ["calc_types" , "deprecated_tasks" , "material_id" ],
14391400 ):
1440- doc_dict : dict = doc .model_dump () if self .use_document_model else doc # type: ignore
1441- for task_id , calc_type in doc_dict ["calc_types" ].items ():
1401+ for task_id , calc_type in doc ["calc_types" ].items ():
14421402 if calc_types and calc_type not in calc_types :
14431403 continue
1444- mp_id = doc_dict ["material_id" ]
1445- if meta .get (mp_id ) is None :
1446- meta [mp_id ] = [{"task_id" : task_id , "calc_type" : calc_type }]
1447- else :
1448- meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1404+ mp_id = doc ["material_id" ]
1405+ meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1406+
14491407 if not meta :
14501408 raise ValueError (f"No tasks found for material id { material_ids } ." )
14511409
0 commit comments