@@ -75,7 +75,7 @@ def search(
7575 num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
7676 chunk_size (int): Number of data entries per chunk.
7777 all_fields (bool): Whether to return all fields in the document. Defaults to True.
78- fields (List[str]): List of fields in EOSDoc to return data for.
78+ fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
7979 Default is material_id and last_updated if all_fields is False.
8080
8181 Returns:
@@ -186,8 +186,8 @@ def search(
186186 efermi : tuple [float , float ] | None = None ,
187187 is_gap_direct : bool | None = None ,
188188 is_metal : bool | None = None ,
189- magnetic_ordering : Ordering | None = None ,
190- path_type : BSPathType = BSPathType .setyawan_curtarolo ,
189+ magnetic_ordering : Ordering | str | None = None ,
190+ path_type : BSPathType | str = BSPathType .setyawan_curtarolo ,
191191 num_chunks : int | None = None ,
192192 chunk_size : int = 1000 ,
193193 all_fields : bool = True ,
@@ -200,8 +200,8 @@ def search(
200200 efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
201201 is_gap_direct (bool): Whether the material has a direct band gap.
202202 is_metal (bool): Whether the material is considered a metal.
203- magnetic_ordering (Ordering): Magnetic ordering of the material.
204- path_type (BSPathType): k-path selection convention for the band structure.
203+ magnetic_ordering (Ordering or str ): Magnetic ordering of the material.
204+ path_type (BSPathType or str ): k-path selection convention for the band structure.
205205 num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
206206 chunk_size (int): Number of data entries per chunk.
207207 all_fields (bool): Whether to return all fields in the document. Defaults to True.
@@ -213,7 +213,9 @@ def search(
213213 """
214214 query_params = defaultdict (dict ) # type: dict
215215
216- query_params ["path_type" ] = path_type .value
216+ query_params ["path_type" ] = (
217+ BSPathType [path_type ] if isinstance (path_type , str ) else path_type
218+ ).value
217219
218220 if band_gap :
219221 query_params .update (
@@ -224,7 +226,15 @@ def search(
224226 query_params .update ({"efermi_min" : efermi [0 ], "efermi_max" : efermi [1 ]})
225227
226228 if magnetic_ordering :
227- query_params .update ({"magnetic_ordering" : magnetic_ordering .value })
229+ query_params .update (
230+ {
231+ "magnetic_ordering" : (
232+ Ordering (magnetic_ordering )
233+ if isinstance (magnetic_ordering , str )
234+ else magnetic_ordering
235+ ).value
236+ }
237+ )
228238
229239 if is_gap_direct is not None :
230240 query_params .update ({"is_gap_direct" : is_gap_direct })
@@ -351,11 +361,11 @@ def search(
351361 self ,
352362 band_gap : tuple [float , float ] | None = None ,
353363 efermi : tuple [float , float ] | None = None ,
354- element : Element | None = None ,
355- magnetic_ordering : Ordering | None = None ,
356- orbital : OrbitalType | None = None ,
357- projection_type : DOSProjectionType = DOSProjectionType .total ,
358- spin : Spin = Spin .up ,
364+ element : Element | str | None = None ,
365+ magnetic_ordering : Ordering | str | None = None ,
366+ orbital : OrbitalType | str | None = None ,
367+ projection_type : DOSProjectionType | str = DOSProjectionType .total ,
368+ spin : Spin | str = Spin .up ,
359369 num_chunks : int | None = None ,
360370 chunk_size : int = 1000 ,
361371 all_fields : bool = True ,
@@ -366,30 +376,54 @@ def search(
366376 Arguments:
367377 band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
368378 efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider.
369- element (Element): Element for element-projected dos data.
370- magnetic_ordering (Ordering): Magnetic ordering of the material.
371- orbital (OrbitalType): Orbital for orbital-projected dos data.
372- projection_type (DOSProjectionType): Projection type of dos data. Default is the total dos.
373- spin (Spin): Spin channel of dos data. If non spin-polarized data is stored in Spin.up
379+ element (Element or str ): Element for element-projected dos data.
380+ magnetic_ordering (Ordering or str ): Magnetic ordering of the material.
381+ orbital (OrbitalType or str ): Orbital for orbital-projected dos data.
382+ projection_type (DOSProjectionType or str ): Projection type of dos data. Default is the total dos.
383+ spin (Spin or str ): Spin channel of dos data. If non spin-polarized data is stored in Spin.up
374384 num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
375385 chunk_size (int): Number of data entries per chunk.
376386 all_fields (bool): Whether to return all fields in the document. Defaults to True.
377- fields (List[str]): List of fields in EOSDoc to return data for.
387+ fields (List[str]): List of fields in ElectronicStructureDoc to return data for.
378388 Default is material_id and last_updated if all_fields is False.
379389
380390 Returns:
381391 ([ElectronicStructureDoc]) List of electronic structure documents
382392 """
383393 query_params = defaultdict (dict ) # type: dict
384394
385- query_params ["projection_type" ] = projection_type .value
386- query_params ["spin" ] = spin .value
395+ query_params ["projection_type" ] = (
396+ DOSProjectionType [projection_type ]
397+ if isinstance (projection_type , str )
398+ else projection_type
399+ ).value
400+ query_params ["spin" ] = (Spin [spin ] if isinstance (spin , str ) else spin ).value
401+
402+ if (
403+ query_params ["projection_type" ] == DOSProjectionType .elemental .value
404+ and element is None
405+ ):
406+ raise MPRestError (
407+ "To query element-projected DOS, you must also specify the `element` onto which the DOS is projected."
408+ )
409+
410+ if (
411+ query_params ["projection_type" ] == DOSProjectionType .orbital .value
412+ and orbital is None
413+ ):
414+ raise MPRestError (
415+ "To query orbital-projected DOS, you must also specify the `orbital` character onto which the DOS is projected."
416+ )
387417
388418 if element :
389- query_params ["element" ] = element .value
419+ query_params ["element" ] = (
420+ Element [element ] if isinstance (element , str ) else element
421+ ).value
390422
391423 if orbital :
392- query_params ["orbital" ] = orbital .value
424+ query_params ["orbital" ] = (
425+ OrbitalType [orbital ] if isinstance (orbital , str ) else orbital
426+ ).value
393427
394428 if band_gap :
395429 query_params .update (
@@ -400,7 +434,15 @@ def search(
400434 query_params .update ({"efermi_min" : efermi [0 ], "efermi_max" : efermi [1 ]})
401435
402436 if magnetic_ordering :
403- query_params .update ({"magnetic_ordering" : magnetic_ordering .value })
437+ query_params .update (
438+ {
439+ "magnetic_ordering" : (
440+ Ordering [magnetic_ordering ]
441+ if isinstance (magnetic_ordering , str )
442+ else magnetic_ordering
443+ ).value
444+ }
445+ )
404446
405447 query_params = {
406448 entry : query_params [entry ]
0 commit comments