@@ -697,9 +697,19 @@ def setup(self) -> None:
697697 self .ctx .current_folder = self .inputs ["pw2wannier90" ]["pw2wannier90" ][
698698 "parent_folder"
699699 ]
700- self .ctx .spin_collinear = (
701- self .inputs ["nscf" ]["pw" ]["parameters" ]["SYSTEM" ].get ("nspin" , 1 ) == 2
702- )
700+ # determine if we should use collinear spin.
701+ if self .should_run_nscf ():
702+ self .ctx .spin_collinear = (
703+ self .inputs ["nscf" ]["pw" ]["parameters" ]["SYSTEM" ].get ("nspin" , 1 )
704+ == 2
705+ )
706+ else :
707+ self .ctx .spin_collinear = False
708+ self .report (
709+ "Warning: "
710+ "currently can not determine whether the workflow uses collinear spin, "
711+ "because it skips scf and nscf steps."
712+ )
703713 else :
704714 self .ctx .spin_collinear = (
705715 self .inputs ["scf" ]["pw" ]["parameters" ]["SYSTEM" ].get ("nspin" , 1 ) == 2
@@ -1333,8 +1343,12 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
13331343 if atom_proj and atom_proj_ext :
13341344 return
13351345
1336- # 1. the calculated number of projections is consistent with QE projwfc.x
1337- check_num_projs = True
1346+ # Determine whether to check number of projections or electrons
1347+ check_num_projs = False
1348+ check_num_elec = True
1349+ # check projections only if have calculated projwfc.x
1350+ if self .should_run_projwfc ():
1351+ check_num_projs = True
13381352 if self .should_run_scf ():
13391353 pseudos = self .inputs ["scf" ]["pw" ]["pseudos" ]
13401354 spin_orbit_coupling = (
@@ -1361,8 +1375,12 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
13611375 )
13621376 else :
13631377 check_num_projs = False
1378+ check_num_elec = False
13641379 pseudos = None
1380+ spin_non_collinear = None
13651381 spin_orbit_coupling = None
1382+
1383+ # 1. the calculated number of projections is consistent with QE projwfc.x
13661384 if check_num_projs :
13671385 args = {
13681386 "structure" : self .ctx .current_structure ,
@@ -1372,16 +1390,13 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
13721390 pseudos # pylint: disable=possibly-used-before-assignment
13731391 ),
13741392 }
1375- if "workchain_projwfc" in self .ctx :
1376- num_proj = len (
1377- self .ctx .workchain_projwfc .outputs ["projections" ].get_orbitals ()
1378- )
1379- params = self .ctx .workchain_wannier90 .inputs ["wannier90" ][
1380- "parameters"
1381- ].get_dict ()
1382- spin_orbit_coupling = params .get ("spinors" , False )
1393+ num_proj = len (
1394+ self .ctx .workchain_projwfc .outputs ["projections" ].get_orbitals ()
1395+ )
13831396 number_of_projections = get_number_of_projections (
1384- ** args , spin_non_collinear = spin_non_collinear , spin_orbit_coupling = spin_orbit_coupling
1397+ ** args ,
1398+ spin_non_collinear = spin_non_collinear ,
1399+ spin_orbit_coupling = spin_orbit_coupling ,
13851400 )
13861401 if number_of_projections != num_proj :
13871402 self .report (
@@ -1390,20 +1405,29 @@ def sanity_check(self): # pylint: disable=inconsistent-return-statements
13901405 return self .exit_codes .ERROR_SANITY_CHECK_FAILED
13911406
13921407 # 2. the number of electrons is consistent with QE output
1393- if "workchain_scf" in self .ctx :
1394- num_elec = self .ctx .workchain_scf .outputs ["output_parameters" ][
1395- "number_of_electrons"
1396- ]
1397- else :
1398- num_elec = self .ctx .workchain_nscf .outputs ["output_parameters" ][
1399- "number_of_electrons"
1400- ]
1401- number_of_electrons = get_number_of_electrons (** args )
1402- if number_of_electrons != num_elec :
1403- self .report (
1404- f"number of electrons { number_of_electrons } != QE output { num_elec } "
1405- )
1406- return self .exit_codes .ERROR_SANITY_CHECK_FAILED
1408+ if check_num_elec :
1409+ args = {
1410+ "structure" : self .ctx .current_structure ,
1411+ # The type of `self.inputs['scf']['pw']['pseudos']` is AttributesFrozendict,
1412+ # we need to convert it to dict, otherwise get_number_of_projections will fail.
1413+ "pseudos" : dict (
1414+ pseudos # pylint: disable=possibly-used-before-assignment
1415+ ),
1416+ }
1417+ if "workchain_scf" in self .ctx :
1418+ num_elec = self .ctx .workchain_scf .outputs ["output_parameters" ][
1419+ "number_of_electrons"
1420+ ]
1421+ else :
1422+ num_elec = self .ctx .workchain_nscf .outputs ["output_parameters" ][
1423+ "number_of_electrons"
1424+ ]
1425+ number_of_electrons = get_number_of_electrons (** args )
1426+ if number_of_electrons != num_elec :
1427+ self .report (
1428+ f"number of electrons { number_of_electrons } != QE output { num_elec } "
1429+ )
1430+ return self .exit_codes .ERROR_SANITY_CHECK_FAILED
14071431
14081432 # pylint: disable=inconsistent-return-statements,too-many-return-statements
14091433 def sanity_check_spin_collinear (
@@ -1443,81 +1467,50 @@ def sanity_check_spin_collinear(
14431467 return
14441468
14451469 # 1. the calculated number of projections is consistent with QE projwfc.x
1470+ check_num_projs = False
1471+ if self .should_run_projwfc ():
1472+ check_num_projs = True
14461473 if "scf" in self .inputs :
14471474 pseudos = self .inputs ["scf" ]["pw" ]["pseudos" ]
1448- else :
1475+ elif "nscf" in self . inputs :
14491476 pseudos = self .inputs ["nscf" ]["pw" ]["pseudos" ]
1450- args = {
1451- "structure" : self .ctx .current_structure ,
1452- # The type of `self.inputs['scf']['pw']['pseudos']` is AttributesFrozendict,
1453- # we need to convert it to dict, otherwise get_number_of_projections will fail.
1454- "pseudos" : dict (pseudos ),
1455- }
1456- if "workchain_projwfc" in self .ctx :
1457- if self .ctx .spin_collinear :
1458- num_proj_up = len (
1459- self .ctx .workchain_projwfc .outputs ["projections_up" ].get_orbitals ()
1460- )
1461- num_proj_down = len (
1462- self .ctx .workchain_projwfc .outputs [
1463- "projections_down"
1464- ].get_orbitals ()
1477+ else :
1478+ pseudos = None
1479+ check_num_projs = False
1480+ if check_num_projs :
1481+ args = {
1482+ "structure" : self .ctx .current_structure ,
1483+ # The type of `self.inputs['scf']['pw']['pseudos']` is AttributesFrozendict,
1484+ # we need to convert it to dict, otherwise get_number_of_projections will fail.
1485+ "pseudos" : dict (pseudos ),
1486+ }
1487+ num_proj_up = len (
1488+ self .ctx .workchain_projwfc .outputs ["projections_up" ].get_orbitals ()
1489+ )
1490+ num_proj_down = len (
1491+ self .ctx .workchain_projwfc .outputs ["projections_down" ].get_orbitals ()
1492+ )
1493+ if num_proj_up != num_proj_down :
1494+ self .report (
1495+ "number of projections in projwfc.x output "
1496+ + f"for spin up { num_proj_up } != spin down { num_proj_down } "
14651497 )
1466- if num_proj_up != num_proj_down :
1467- self .report (
1468- "number of projections in projwfc.x output "
1469- + f"for spin up { num_proj_up } != spin down { num_proj_down } "
1470- )
1471- return self .exit_codes .ERROR_SANITY_CHECK_FAILED
1498+ return self .exit_codes .ERROR_SANITY_CHECK_FAILED
14721499
1473- num_proj = num_proj_up
1474- else :
1475- num_proj = self .ctx .workchain_projwfc .outputs [
1476- "projections"
1477- ].get_orbitals ()
1478- if "workchain_wannier90_up" in self .ctx :
1479- params = self .ctx .workchain_wannier90_up .inputs ["wannier90" ][
1480- "parameters"
1481- ].get_dict ()
1482- spin_orbit_coupling = params .get ("spinors" , False )
1483- number_of_projections = get_number_of_projections (
1484- ** args , spin_orbit_coupling = spin_orbit_coupling
1485- )
1486- if number_of_projections != num_proj :
1487- self .report (
1488- f"number of projections { number_of_projections } != projwfc.x output { num_proj } "
1489- )
1490- return self .exit_codes .ERROR_SANITY_CHECK_FAILED
1491- if "workchain_wannier90_down" in self .ctx :
1492- params = self .ctx .workchain_wannier90_down .inputs ["wannier90" ][
1493- "parameters"
1494- ].get_dict ()
1495- spin_orbit_coupling = params .get ("spinors" , False )
1496- number_of_projections = get_number_of_projections (
1497- ** args , spin_orbit_coupling = spin_orbit_coupling
1498- )
1499- if number_of_projections != num_proj :
1500- self .report (
1501- f"number of projections { number_of_projections } != projwfc.x output { num_proj } "
1502- )
1503- return self .exit_codes .ERROR_SANITY_CHECK_FAILED
1504-
1505- if "workchain_wannier90" in self .ctx and not self .ctx .spin_collinear :
1506- params = self .ctx .workchain_wannier90 .inputs ["wannier90" ][
1507- "parameters"
1508- ].get_dict ()
1509- spin_orbit_coupling = params .get ("spinors" , False )
1510- number_of_projections = get_number_of_projections (
1511- ** args , spin_orbit_coupling = spin_orbit_coupling
1500+ num_proj = num_proj_up
1501+
1502+ number_of_projections = get_number_of_projections (
1503+ ** args , spin_non_collinear = False , spin_orbit_coupling = False
1504+ )
1505+ if number_of_projections != num_proj :
1506+ self .report (
1507+ f"number of projections { number_of_projections } != projwfc.x output { num_proj } "
15121508 )
1513- if number_of_projections != num_proj :
1514- self .report (
1515- f"number of projections { number_of_projections } != projwfc.x output { num_proj } "
1516- )
1517- return self .exit_codes .ERROR_SANITY_CHECK_FAILED
1509+ return self .exit_codes .ERROR_SANITY_CHECK_FAILED
15181510
15191511 # 2. the number of electrons is consistent with QE output
15201512 # only check num electrons when we already know pseudos in the check num projectors step
1513+ check_num_elec = check_num_projs
15211514 if "workchain_scf" in self .ctx :
15221515 num_elec = self .ctx .workchain_scf .outputs ["output_parameters" ][
15231516 "number_of_electrons"
@@ -1527,9 +1520,9 @@ def sanity_check_spin_collinear(
15271520 "number_of_electrons"
15281521 ]
15291522 else :
1530- check_num_elecs = False
1523+ check_num_elec = False
15311524 num_elec = None # to avoid pylint errors
1532- if check_num_elecs :
1525+ if check_num_elec :
15331526 number_of_electrons = get_number_of_electrons (** args )
15341527 if (
15351528 number_of_electrons
0 commit comments