22
33import warnings
44from collections import defaultdict
5+ from typing import TYPE_CHECKING
56
67from emmet .core .electronic_structure import (
78 BSPathType ,
1516from mp_api .client .core import BaseRester , MPRestError
1617from mp_api .client .core .utils import validate_ids
1718
19+ if TYPE_CHECKING :
20+ from pymatgen .electronic_structure .dos import CompleteDos
21+
1822
1923class ElectronicStructureRester (BaseRester ):
2024 suffix = "materials/electronic_structure"
@@ -142,9 +146,28 @@ def search(
142146 )
143147
144148
145- class BandStructureRester (BaseRester ):
149+ class BaseESPropertyRester (BaseRester ):
150+ _es_rester : ElectronicStructureRester | None = None
151+ document_model = ElectronicStructureDoc
152+
153+ @property
154+ def es_rester (self ) -> ElectronicStructureRester :
155+ if not self ._es_rester :
156+ self ._es_rester = ElectronicStructureRester (
157+ api_key = self .api_key ,
158+ endpoint = self .base_endpoint ,
159+ include_user_agent = self .include_user_agent ,
160+ session = self .session ,
161+ monty_decode = self .monty_decode ,
162+ use_document_model = self .use_document_model ,
163+ headers = self .headers ,
164+ mute_progress_bars = self .mute_progress_bars ,
165+ )
166+ return self ._es_rester
167+
168+
169+ class BandStructureRester (BaseESPropertyRester ):
146170 suffix = "materials/electronic_structure/bandstructure"
147- document_model = ElectronicStructureDoc # type: ignore
148171
149172 def search_bandstructure_summary (self , * args , ** kwargs ): # pragma: no cover
150173 """Deprecated."""
@@ -258,19 +281,8 @@ def get_bandstructure_from_material_id(
258281 Returns:
259282 bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
260283 """
261- es_rester = ElectronicStructureRester (
262- api_key = self .api_key ,
263- endpoint = self .base_endpoint ,
264- include_user_agent = self .include_user_agent ,
265- session = self .session ,
266- monty_decode = self .monty_decode ,
267- use_document_model = self .use_document_model ,
268- headers = self .headers ,
269- mute_progress_bars = self .mute_progress_bars ,
270- )
271-
272284 if line_mode :
273- bs_doc = es_rester .search (
285+ bs_doc = self . es_rester .search (
274286 material_ids = material_id , fields = ["bandstructure" ]
275287 )
276288 if not bs_doc :
@@ -293,7 +305,9 @@ def get_bandstructure_from_material_id(
293305
294306 else :
295307 if not (
296- bs_doc := es_rester .search (material_ids = material_id , fields = ["dos" ])
308+ bs_doc := self .es_rester .search (
309+ material_ids = material_id , fields = ["dos" ]
310+ )
297311 ):
298312 raise MPRestError ("No electronic structure data found." )
299313
@@ -319,9 +333,8 @@ def get_bandstructure_from_material_id(
319333 raise MPRestError ("No band structure object found." )
320334
321335
322- class DosRester (BaseRester ):
336+ class DosRester (BaseESPropertyRester ):
323337 suffix = "materials/electronic_structure/dos"
324- document_model = ElectronicStructureDoc # type: ignore
325338
326339 def search_dos_summary (self , * args , ** kwargs ): # pragma: no cover
327340 """Deprecated."""
@@ -403,7 +416,7 @@ def search(
403416 ** query_params ,
404417 )
405418
406- def get_dos_from_task_id (self , task_id : str ):
419+ def get_dos_from_task_id (self , task_id : str ) -> CompleteDos :
407420 """Get the density of states pymatgen object associated with a given calculation ID.
408421
409422 Arguments:
@@ -431,18 +444,9 @@ def get_dos_from_material_id(self, material_id: str):
431444 Returns:
432445 dos (CompleteDos): CompleteDos object
433446 """
434- es_rester = ElectronicStructureRester (
435- api_key = self .api_key ,
436- endpoint = self .base_endpoint ,
437- include_user_agent = self .include_user_agent ,
438- session = self .session ,
439- monty_decode = self .monty_decode ,
440- use_document_model = self .use_document_model ,
441- headers = self .headers ,
442- mute_progress_bars = self .mute_progress_bars ,
443- )
444-
445- if not (dos_doc := es_rester .search (material_ids = material_id , fields = ["dos" ])):
447+ if not (
448+ dos_doc := self .es_rester .search (material_ids = material_id , fields = ["dos" ])
449+ ):
446450 return None
447451
448452 if not (dos_data := dos_doc [0 ].get ("dos" )):
0 commit comments