@@ -28,21 +28,21 @@ def validate_inputs(inputs, _):
2828 """Validate the inputs before launching the WorkChain."""
2929 structure = inputs ['structure' ]
3030 elements_present = [kind .name for kind in structure .kinds ]
31- absorbing_elements_list = sorted (inputs ['elements_list' ])
3231 abs_atom_marker = inputs ['abs_atom_marker' ].value
3332 if abs_atom_marker in elements_present :
3433 raise ValidationError (
3534 f'The marker given for the absorbing atom ("{ abs_atom_marker } ") matches an existing Kind in the '
3635 f'input structure ({ elements_present } ).'
3736 )
38-
39- if inputs ['calc_binding_energy' ].value :
40- ce_list = sorted (inputs ['correction_energies' ].get_dict ().keys ())
41- if ce_list != absorbing_elements_list :
42- raise ValidationError (
43- f'The ``correction_energies`` provided ({ ce_list } ) does not match the list of'
44- f' absorbing elements ({ absorbing_elements_list } )'
45- )
37+ if 'elements_list' in inputs :
38+ absorbing_elements_list = sorted (inputs ['elements_list' ])
39+ if inputs ['calc_binding_energy' ].value :
40+ ce_list = sorted (inputs ['correction_energies' ].get_dict ().keys ())
41+ if ce_list != absorbing_elements_list :
42+ raise ValidationError (
43+ f'The ``correction_energies`` provided ({ ce_list } ) does not match the list of'
44+ f' absorbing elements ({ absorbing_elements_list } )'
45+ )
4646
4747
4848class XpsWorkChain (ProtocolMixin , WorkChain ):
@@ -81,7 +81,7 @@ def define(cls, spec):
8181 spec .expose_inputs (
8282 PwBaseWorkChain ,
8383 namespace = 'ch_scf' ,
84- exclude = ('kpoints' , ' pw.structure' ),
84+ exclude = ('pw.structure' , ),
8585 namespace_options = {
8686 'help' : ('Input parameters for the basic xps workflow (core-hole SCF).' ),
8787 'validator' : None
@@ -170,6 +170,14 @@ def define(cls, spec):
170170 'The list of elements to be considered for analysis, each must be valid elements of the periodic table.'
171171 )
172172 )
173+ spec .input (
174+ 'atoms_list' ,
175+ valid_type = orm .List ,
176+ required = False ,
177+ help = (
178+ 'The indices of atoms to be considered for analysis.'
179+ )
180+ )
173181 spec .input (
174182 'calc_binding_energy' ,
175183 valid_type = orm .Bool ,
@@ -233,12 +241,14 @@ def define(cls, spec):
233241 spec .output (
234242 'supercell_structure' ,
235243 valid_type = orm .StructureData ,
244+ required = False ,
236245 help = ('The supercell of ``outputs.standardized_structure`` used to generate structures for'
237246 ' XPS sub-processes.' )
238247 )
239248 spec .output (
240249 'symmetry_analysis_data' ,
241250 valid_type = orm .Dict ,
251+ required = False ,
242252 help = 'The output parameters from ``get_xspectra_structures()``.'
243253 )
244254 spec .output (
@@ -366,8 +376,8 @@ def get_treatment_filepath(cls):
366376 @classmethod
367377 def get_builder_from_protocol (
368378 cls , code , structure , pseudos , core_hole_treatments = None , protocol = None ,
369- overrides = None , elements_list = None , options = None ,
370- structure_preparation_settings = None , ** kwargs
379+ overrides = None , elements_list = None , atoms_list = None , options = None ,
380+ structure_preparation_settings = None , correction_energies = None , ** kwargs
371381 ):
372382 """Return a builder prepopulated with inputs selected according to the chosen protocol.
373383
@@ -386,9 +396,6 @@ def get_builder_from_protocol(
386396 """
387397
388398 inputs = cls .get_protocol_inputs (protocol , overrides )
389- calc_binding_energy = kwargs .pop ('calc_binding_energy' , False )
390- correction_energies = kwargs .pop ('correction_energies' , orm .Dict ())
391-
392399 pw_args = (code , structure , protocol )
393400 # xspectra_args = (pw_code, xs_code, structure, protocol, upf2plotcore_code)
394401
@@ -412,8 +419,11 @@ def get_builder_from_protocol(
412419 builder .ch_scf = ch_scf
413420 builder .structure = structure
414421 builder .abs_atom_marker = abs_atom_marker
415- builder .calc_binding_energy = calc_binding_energy
416- builder .correction_energies = correction_energies
422+ if correction_energies :
423+ builder .correction_energies = orm .Dict (correction_energies )
424+ builder .calc_binding_energy = orm .Bool (True )
425+ else :
426+ builder .calc_binding_energy = orm .Bool (False )
417427 builder .clean_workdir = orm .Bool (inputs ['clean_workdir' ])
418428 core_hole_pseudos = {}
419429 gipaw_pseudos = {}
@@ -434,6 +444,12 @@ def get_builder_from_protocol(
434444 for element in elements_list :
435445 core_hole_pseudos [element ] = pseudos [element ]['core_hole' ]
436446 gipaw_pseudos [element ] = pseudos [element ]['gipaw' ]
447+ elif atoms_list :
448+ builder .atoms_list = orm .List (atoms_list )
449+ for index in atoms_list :
450+ element = structure .sites [index ].kind_name
451+ core_hole_pseudos [element ] = pseudos [element ]['core_hole' ]
452+ gipaw_pseudos [element ] = pseudos [element ]['gipaw' ]
437453 # if no elements list is given, we instead initalise the pseudos dict with all
438454 # elements in the structure
439455 else :
@@ -453,12 +469,18 @@ def get_builder_from_protocol(
453469
454470 def setup (self ):
455471 """Init required context variables."""
456- custom_elements_list = self .inputs .get ('elements_list' , None )
457- if not custom_elements_list :
472+ elements_list = self .inputs .get ('elements_list' , None )
473+ atoms_list = self .inputs .get ('atoms_list' , None )
474+ if elements_list :
475+ self .ctx .elements_list = elements_list .get_list ()
476+ self .ctx .atoms_list = None
477+ elif atoms_list :
478+ self .ctx .atoms_list = atoms_list .get_list ()
479+ self .ctx .elements_list = None
480+ else :
458481 structure = self .inputs .structure
459482 self .ctx .elements_list = [Kind .symbol for Kind in structure .kinds ]
460- else :
461- self .ctx .elements_list = custom_elements_list .get_list ()
483+
462484
463485
464486 def should_run_relax (self ):
@@ -511,48 +533,59 @@ def prepare_structures(self):
511533 formatted as {<variable_name> : <parameter>} for each variable in the
512534 ``get_symmetry_dataset()`` method.
513535 """
536+ from aiida_quantumespresso .workflows .functions .get_marked_structures import get_marked_structures
514537 from aiida_quantumespresso .workflows .functions .get_xspectra_structures import get_xspectra_structures
515538
516- elements_list = orm .List (self .ctx .elements_list )
517- inputs = {
518- 'absorbing_elements_list' : elements_list ,
519- 'absorbing_atom_marker' : self .inputs .abs_atom_marker ,
520- 'metadata' : {
521- 'call_link_label' : 'get_xspectra_structures'
539+ input_structure = self .inputs .structure if 'relax' not in self .inputs else self .ctx .relaxed_structure
540+ if self .ctx .elements_list :
541+ elements_list = orm .List (self .ctx .elements_list )
542+ inputs = {
543+ 'absorbing_elements_list' : elements_list ,
544+ 'absorbing_atom_marker' : self .inputs .abs_atom_marker ,
545+ 'metadata' : {
546+ 'call_link_label' : 'get_xspectra_structures'
547+ }
548+ } # populate this further once the schema for WorkChain options is figured out
549+ if 'structure_preparation_settings' in self .inputs :
550+ optional_cell_prep = self .inputs .structure_preparation_settings
551+ for key , node in optional_cell_prep .items ():
552+ inputs [key ] = node
553+ if 'spglib_settings' in self .inputs :
554+ spglib_settings = self .inputs .spglib_settings
555+ inputs ['spglib_settings' ] = spglib_settings
556+ else :
557+ spglib_settings = None
558+
559+ result = get_xspectra_structures (input_structure , ** inputs )
560+
561+ supercell = result .pop ('supercell' )
562+ out_params = result .pop ('output_parameters' )
563+ if out_params .get_dict ().get ('structure_is_standardized' , None ):
564+ standardized = result .pop ('standardized_structure' )
565+ self .out ('standardized_structure' , standardized )
566+
567+ # structures_to_process = {Key : Value for Key, Value in result.items()}
568+ for site in ['output_parameters' , 'supercell' , 'standardized_structure' ]:
569+ result .pop (site , None )
570+ self .ctx .supercell = supercell
571+ self .ctx .equivalent_sites_data = out_params ['equivalent_sites_data' ]
572+ self .out ('supercell_structure' , supercell )
573+ self .out ('symmetry_analysis_data' , out_params )
574+ elif self .ctx .atoms_list :
575+ atoms_list = orm .List (self .ctx .atoms_list )
576+ inputs = {
577+ 'atoms_list' : atoms_list ,
578+ 'marker' : self .inputs .abs_atom_marker ,
579+ 'metadata' : {
580+ 'call_link_label' : 'get_marked_structures'
581+ }
522582 }
523- } # populate this further once the schema for WorkChain options is figured out
524- if 'structure_preparation_settings' in self .inputs :
525- optional_cell_prep = self .inputs .structure_preparation_settings
526- for key , node in optional_cell_prep .items ():
527- inputs [key ] = node
528- if 'spglib_settings' in self .inputs :
529- spglib_settings = self .inputs .spglib_settings
530- inputs ['spglib_settings' ] = spglib_settings
531- else :
532- spglib_settings = None
533-
534- if 'relax' in self .inputs :
535- relaxed_structure = self .ctx .relaxed_structure
536- result = get_xspectra_structures (relaxed_structure , ** inputs )
537- else :
538- result = get_xspectra_structures (self .inputs .structure , ** inputs )
539-
540- supercell = result .pop ('supercell' )
541- out_params = result .pop ('output_parameters' )
542- if out_params .get_dict ().get ('structure_is_standardized' , None ):
543- standardized = result .pop ('standardized_structure' )
544- self .out ('standardized_structure' , standardized )
545-
546- # structures_to_process = {Key : Value for Key, Value in result.items()}
547- for site in ['output_parameters' , 'supercell' , 'standardized_structure' ]:
548- result .pop (site , None )
583+ result = get_marked_structures (input_structure , ** inputs )
584+ self .ctx .supercell = input_structure
585+ self .ctx .equivalent_sites_data = result .pop ('output_parameters' ).get_dict ()
549586 structures_to_process = {f'{ Key .split ("_" )[0 ]} _{ Key .split ("_" )[1 ]} ' : Value for Key , Value in result .items ()}
550- self .ctx . supercell = supercell
587+ self .report ( f'structures_to_process: { structures_to_process } ' )
551588 self .ctx .structures_to_process = structures_to_process
552- self .ctx .equivalent_sites_data = out_params ['equivalent_sites_data' ]
553-
554- self .out ('supercell_structure' , supercell )
555- self .out ('symmetry_analysis_data' , out_params )
556589
557590 def should_run_gs_scf (self ):
558591 """If the 'calc_binding_energy' input namespace is True, we run a scf calculation for the supercell."""
@@ -566,9 +599,9 @@ def run_gs_scf(self):
566599 inputs .metadata .call_link_label = 'supercell_xps'
567600
568601 inputs = prepare_process_inputs (PwBaseWorkChain , inputs )
569- equivalent_sites_data = self . ctx . equivalent_sites_data
570- for site in equivalent_sites_data :
571- abs_element = equivalent_sites_data [site ]['symbol' ]
602+ # pseudos for all elements to be calculated should be replaced
603+ for site in self . ctx . equivalent_sites_data :
604+ abs_element = self . ctx . equivalent_sites_data [site ]['symbol' ]
572605 inputs .pw .pseudos [abs_element ] = self .inputs .gipaw_pseudos [abs_element ]
573606 running = self .submit (PwBaseWorkChain , ** inputs )
574607
@@ -600,7 +633,6 @@ def run_all_scf(self):
600633 equivalent_sites_data = self .ctx .equivalent_sites_data
601634 abs_atom_marker = self .inputs .abs_atom_marker .value
602635
603-
604636 for site in structures_to_process :
605637 inputs = AttributeDict (self .exposed_inputs (PwBaseWorkChain , namespace = 'ch_scf' ))
606638 structure = structures_to_process [site ]
@@ -630,9 +662,10 @@ def run_all_scf(self):
630662
631663 core_hole_pseudo = self .inputs .core_hole_pseudos [abs_element ]
632664 inputs .pw .pseudos [abs_atom_marker ] = core_hole_pseudo
633- # all element in the elements_list should be replaced
634- for element in self .inputs .elements_list :
635- inputs .pw .pseudos [element ] = self .inputs .gipaw_pseudos [element ]
665+ # pseudos for all elements to be calculated should be replaced
666+ for key in self .ctx .equivalent_sites_data :
667+ abs_element = self .ctx .equivalent_sites_data [key ]['symbol' ]
668+ inputs .pw .pseudos [abs_element ] = self .inputs .gipaw_pseudos [abs_element ]
636669 # remove pseudo if the only element is replaced by the marker
637670 inputs .pw .pseudos = {kind .name : inputs .pw .pseudos [kind .name ] for kind in structure .kinds }
638671
@@ -674,11 +707,15 @@ def results(self):
674707 kwargs ['correction_energies' ] = self .inputs .correction_energies
675708 kwargs ['metadata' ] = {'call_link_label' : 'compile_final_spectra' }
676709
677- equivalent_sites_data = orm .Dict (dict = self .ctx .equivalent_sites_data )
678- elements_list = orm .List (list = self .ctx .elements_list )
710+ if self .ctx .elements_list :
711+ elements_list = orm .List (list = self .ctx .elements_list )
712+ else :
713+ symbols = {value ['symbol' ] for value in self .ctx .equivalent_sites_data .values ()}
714+ elements_list = orm .List (list (symbols ))
679715 voight_gamma = self .inputs .voight_gamma
680716 voight_sigma = self .inputs .voight_sigma
681717
718+ equivalent_sites_data = orm .Dict (dict = self .ctx .equivalent_sites_data )
682719 result = get_spectra_by_element (
683720 elements_list ,
684721 equivalent_sites_data ,
0 commit comments