33import itertools
44import os
55import warnings
6+ from collections import defaultdict
67from functools import cache , lru_cache
78from typing import TYPE_CHECKING
89
@@ -419,20 +420,15 @@ def get_task_ids_associated_with_material_id(
419420 if not tasks :
420421 return []
421422
422- calculations = (
423- tasks [0 ].calc_types # type: ignore
424- if self .use_document_model
425- else tasks [0 ]["calc_types" ] # type: ignore
426- )
423+ calculations = tasks [0 ]["calc_types" ]
427424
428425 if calc_types :
429426 return [
430427 task
431428 for task , calc_type in calculations .items ()
432429 if calc_type in calc_types
433430 ]
434- else :
435- return list (calculations .keys ())
431+ return list (calculations .keys ())
436432
437433 def get_structure_by_material_id (
438434 self , material_id : str , final : bool = True , conventional_unit_cell : bool = False
@@ -534,11 +530,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
534530 List of BibTeX references ([str])
535531 """
536532 docs = self .materials .provenance .search (material_ids = material_id )
537-
538- if not docs :
539- return []
540-
541- return docs [0 ].references if self .use_document_model else docs [0 ]["references" ] # type: ignore
533+ return docs [0 ]["references" ] if docs else []
542534
543535 def get_material_ids (
544536 self ,
@@ -553,17 +545,16 @@ def get_material_ids(
553545 Returns:
554546 List of all materials ids ([MPID])
555547 """
548+ inp_k = "formula"
556549 if isinstance (chemsys_formula , list ) or (
557550 isinstance (chemsys_formula , str ) and "-" in chemsys_formula
558551 ):
559- input_params = {"chemsys" : chemsys_formula }
560- else :
561- input_params = {"formula" : chemsys_formula }
552+ inp_k = "chemsys"
562553
563554 return sorted (
564- doc . material_id if self . use_document_model else doc ["material_id" ] # type: ignore
555+ doc ["material_id" ]
565556 for doc in self .materials .search (
566- ** input_params , # type: ignore
557+ ** { inp_k : chemsys_formula },
567558 all_fields = False ,
568559 fields = ["material_id" ],
569560 )
@@ -596,10 +587,8 @@ def get_structures(
596587 all_fields = False ,
597588 fields = ["structure" ],
598589 )
599- if not self .use_document_model :
600- return [doc ["structure" ] for doc in docs ] # type: ignore
601590
602- return [doc . structure for doc in docs ] # type: ignore
591+ return [doc ["structure" ] for doc in docs ]
603592 else :
604593 structures = []
605594
@@ -608,12 +597,7 @@ def get_structures(
608597 all_fields = False ,
609598 fields = ["initial_structures" ],
610599 ):
611- initial_structures = (
612- doc .initial_structures # type: ignore
613- if self .use_document_model
614- else doc ["initial_structures" ] # type: ignore
615- )
616- structures .extend (initial_structures )
600+ structures .extend (doc ["initial_structures" ])
617601
618602 return structures
619603
@@ -718,7 +702,7 @@ def get_entries(
718702 if additional_criteria :
719703 input_params = {** input_params , ** additional_criteria }
720704
721- entries = []
705+ entries : set [ ComputedStructureEntry ] = set ()
722706
723707 fields = (
724708 ["entries" , "thermo_type" ]
@@ -733,24 +717,17 @@ def get_entries(
733717 )
734718
735719 for doc in docs :
736- entry_list = (
737- doc .entries .values () # type: ignore
738- if self .use_document_model
739- else doc ["entries" ].values () # type: ignore
740- )
720+ entry_list = doc ["entries" ].values ()
741721 for entry in entry_list :
742- entry_dict : dict = entry .as_dict () if self . monty_decode else entry # type: ignore
722+ entry_dict : dict = entry .as_dict () if hasattr ( entry , "as_dict" ) else entry # type: ignore
743723 if not compatible_only :
744724 entry_dict ["correction" ] = 0.0
745725 entry_dict ["energy_adjustments" ] = []
746726
747727 if property_data :
748- for property in property_data :
749- entry_dict ["data" ][property ] = (
750- doc .model_dump ()[property ] # type: ignore
751- if self .use_document_model
752- else doc [property ] # type: ignore
753- )
728+ entry_dict ["data" ] = {
729+ property : doc [property ] for property in property_data
730+ }
754731
755732 if conventional_unit_cell :
756733 entry_struct = Structure .from_dict (entry_dict ["structure" ])
@@ -771,15 +748,10 @@ def get_entries(
771748 if "n_atoms" in correction :
772749 correction ["n_atoms" ] *= site_ratio
773750
774- entry = (
775- ComputedStructureEntry .from_dict (entry_dict )
776- if self .monty_decode
777- else entry_dict
778- )
751+ # Need to store object to permit de-duplication
752+ entries .add (ComputedStructureEntry .from_dict (entry_dict ))
779753
780- entries .append (entry )
781-
782- return entries
754+ return [e if self .monty_decode else e .as_dict () for e in entries ]
783755
784756 def get_pourbaix_entries (
785757 self ,
@@ -1310,9 +1282,7 @@ def get_wulff_shape(self, material_id: str):
13101282 if not doc :
13111283 return None
13121284
1313- surfaces : list = (
1314- doc [0 ].surfaces if self .use_document_model else doc [0 ]["surfaces" ] # type: ignore
1315- )
1285+ surfaces : list = doc [0 ]["surfaces" ]
13161286
13171287 lattice = (
13181288 SpacegroupAnalyzer (structure ).get_conventional_standard_structure ().lattice
@@ -1382,17 +1352,8 @@ def get_charge_density_from_material_id(
13821352 if len (results ) == 0 :
13831353 return None
13841354
1385- latest_doc = max ( # type: ignore
1386- results ,
1387- key = lambda x : (
1388- x .last_updated # type: ignore
1389- if self .use_document_model
1390- else x ["last_updated" ]
1391- ), # type: ignore
1392- )
1393- task_id = (
1394- latest_doc .task_id if self .use_document_model else latest_doc ["task_id" ]
1395- )
1355+ latest_doc = max (results , key = lambda x : x ["last_updated" ])
1356+ task_id = latest_doc ["task_id" ]
13961357 return self .get_charge_density_from_task_id (task_id , inc_task_doc )
13971358
13981359 def get_download_info (self , material_ids , calc_types = None , file_patterns = None ):
@@ -1414,20 +1375,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14141375 else []
14151376 )
14161377
1417- meta = {}
1378+ meta = defaultdict ( list )
14181379 for doc in self .materials .search ( # type: ignore
14191380 task_ids = material_ids ,
14201381 fields = ["calc_types" , "deprecated_tasks" , "material_id" ],
14211382 ):
1422- doc_dict : dict = doc .model_dump () if self .use_document_model else doc # type: ignore
1423- for task_id , calc_type in doc_dict ["calc_types" ].items ():
1383+ for task_id , calc_type in doc ["calc_types" ].items ():
14241384 if calc_types and calc_type not in calc_types :
14251385 continue
1426- mp_id = doc_dict ["material_id" ]
1427- if meta .get (mp_id ) is None :
1428- meta [mp_id ] = [{"task_id" : task_id , "calc_type" : calc_type }]
1429- else :
1430- meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1386+ mp_id = doc ["material_id" ]
1387+ meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1388+
14311389 if not meta :
14321390 raise ValueError (f"No tasks found for material id { material_ids } ." )
14331391
0 commit comments