Skip to content

Commit 005a7ad

Browse files
committed
auto-calculate ADDED_MOS based on basissets
1 parent 05d8199 commit 005a7ad

File tree

4 files changed

+152
-0
lines changed

4 files changed

+152
-0
lines changed

aiida_cp2k/calculations/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
validate_pseudos_namespace,
2323
write_basissets,
2424
write_pseudos,
25+
estimate_added_mos,
2526
)
2627
from ..utils import Cp2kInput
2728

@@ -167,6 +168,21 @@ def prepare_for_submission(self, folder):
167168
self.inputs.structure if 'structure' in self.inputs else None)
168169
write_basissets(inp, self.inputs.basissets, folder)
169170

171+
# if we have both basissets and structure we can start helping the user :)
172+
if 'basissets' in self.inputs and 'structure' in self.inputs:
173+
try:
174+
scf_section = inp.get_section_dict('FORCE_EVAL/DFT/SCF')
175+
176+
if 'SMEAR' in scf_section and 'ADDED_MOS' not in scf_section:
177+
# now is our time to shine!
178+
added_mos = estimate_added_mos(self.inputs.basissets, self.inputs.structure)
179+
inp.add_keyword('FORCE_EVAL/DFT/SCF/ADDED_MOS', added_mos)
180+
self.logger.info(f'The FORCE_EVAL/DFT/SCF/ADDED_MOS was added with an automatically estimated value'
181+
f' of {added_mos}')
182+
183+
except (KeyError, TypeError): # no SCF, no smearing, or multiple FORCE_EVAL, nothing to do (yet)
184+
pass
185+
170186
if 'pseudos' in self.inputs:
171187
validate_pseudos(inp, self.inputs.pseudos, self.inputs.structure if 'structure' in self.inputs else None)
172188
write_pseudos(inp, self.inputs.pseudos, folder)

aiida_cp2k/utils/datatype_helpers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,31 @@ def validate_basissets(inp, basissets, structure):
160160
kind_sec["ELEMENT"] = bset.element
161161

162162

163+
def estimate_added_mos(basissets, structure, fraction=0.3):
    """Estimate a value for ADDED_MOS from the orbital basis sets in use.

    For every basis set of type ``ORB`` the number of sites whose symbol
    equals the basis set's element is multiplied by the basis set's number of
    orbital functions. ``fraction`` of the accumulated total (truncated to an
    integer) is returned, but never less than one MO per site.

    :param basissets: (possibly nested) mapping of labels to basis sets, as accepted by ``_unpack``
    :param structure: structure providing ``sites`` and ``get_kind``
    :param fraction: share of the counted orbital functions to use as estimate
    :return: the estimated number of added MOs as an integer
    """

    site_symbols = [structure.get_kind(site.kind_name).get_symbols_string() for site in structure.sites]
    orbital_functions = 0

    # Known overcounting in the following cases:
    #  * a mix of ORB basissets for the same chemical symbol but different sites
    #  * multiple basissets for one element (merged within CP2K)

    for label, bset in _unpack(basissets):
        # everything after the first "_" is the basis set type;
        # a label without "_" carries no explicit type and counts as ORB
        _, separator, suffix = label.partition("_")
        bstype = suffix if separator else "ORB"

        if bstype == "ORB":  # non-ORB basissets do not contribute
            orbital_functions += site_symbols.count(bset.element) * bset.n_orbital_functions

    # at least one additional MO per site, otherwise a fraction of the total number of orbital functions
    one_per_site = len(site_symbols)
    return max(one_per_site, int(fraction * orbital_functions))
186+
187+
163188
def write_basissets(inp, basissets, folder):
164189
"""Writes the unified BASIS_SETS file with the used basissets"""
165190
_write_gdt(inp, basissets, folder, "BASIS_SET_FILE_NAME", "BASIS_SETS")

aiida_cp2k/utils/input_generator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,37 @@ def add_keyword(self, kwpath, value, override=True, conflicting_keys=None):
5656

5757
Cp2kInput._add_keyword(kwpath, value, self._params, ovrd=override, cfct=conflicting_keys)
5858

59+
def get_section_dict(self, kwpath):
    """Look up ``kwpath`` as a section.

    Delegates to ``_get_section_or_kw`` with ``section_requested=True``.
    """
    return self._get_section_or_kw(kwpath, section_requested=True)
61+
62+
def get_keyword_value(self, kwpath):
    """Look up ``kwpath`` as a keyword.

    Delegates to ``_get_section_or_kw`` with ``section_requested=False``.
    """
    return self._get_section_or_kw(kwpath, section_requested=False)
64+
65+
def _get_section_or_kw(self, kwpath, section_requested):
    """Retrieve either a section or a keyword given a path.

    :param kwpath: path to the entry, either a "/"-separated string or a
        sequence of names; every name is upper-cased before the lookup
    :param section_requested: True if the entry must be a section (a mapping),
        False if it must be a keyword (anything that is not a mapping)
    :return: a deep copy of the entry, so callers cannot modify the stored parameters
    :raises KeyError: if a component of the path does not exist
    :raises TypeError: if the entry is not of the requested kind, or if a
        component before the end of the path is not a section
    """

    if isinstance(kwpath, str):  # turn to list of sections
        kwpath = kwpath.split("/")

    # accept any case, but internally we use uppercase
    path = [name.upper() for name in kwpath]
    # rendered once up front: the path is reported in full in every error message
    full_path = "/".join(path)

    current = self._params

    for name in path:
        # only sections can be descended into (e.g. a repeated section stored as a list cannot)
        if not isinstance(current, Mapping):
            raise TypeError("Section '{}' requested, but a keyword was found along the path".format(full_path))

        try:
            current = current[name]
        except KeyError:
            raise KeyError("Section '{}' not found in parameters".format(full_path)) from None

    if isinstance(current, Mapping):
        if not section_requested:
            raise TypeError("Keyword '{}' requested, but section found".format(full_path))
    elif section_requested:
        raise TypeError("Section '{}' requested, but keyword found".format(full_path))

    return deepcopy(current)
89+
5990
def render(self):
6091
output = [self.DISCLAIMER]
6192
self._render_section(output, deepcopy(self._params))

test/test_gaussian_datatypes.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,83 @@ def test_without_kinds(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database):
778778

779779
_, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)
780780
assert calc_node.exit_status == 0
781+
782+
783+
def test_added_mos(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database):  # pylint: disable=unused-argument
    """Run CP2K with basis sets from gaussian.basisset and a SMEAR section but no ADDED_MOS.

    The generated ``aiida.inp`` must contain an ADDED_MOS keyword and the
    calculation must finish with exit status 0.
    """

    # structure: a water molecule centered in a box with 2.0 of vacuum
    water = ase.build.molecule("H2O")
    water.center(vacuum=2.0)
    structure = StructureData(ase=water)

    # parameters: the SCF section enables smearing and deliberately leaves out ADDED_MOS
    scf_section = {
        "EPS_SCF": 1e-08,
        "MAX_SCF": 200,
        "MIXING": {
            "METHOD": "BROYDEN_MIXING",
            "ALPHA": 0.4,
        },
        "SMEAR": {
            "METHOD": "FERMI_DIRAC",
            "ELECTRONIC_TEMPERATURE": 300.0,
        },
    }
    dft_section = {
        "XC": {
            "XC_FUNCTIONAL": {
                "_": "PBE",
            },
        },
        "MGRID": {
            "CUTOFF": 1000.0,
            "REL_CUTOFF": 100.0,
        },
        "QS": {
            "METHOD": "GPW",
            "EXTRAPOLATION": "USE_GUESS",
        },
        "SCF": scf_section,
        "KPOINTS": {
            "SCHEME": "MONKHORST-PACK 4 4 2",
            "FULL_GRID": False,
            "SYMMETRY": False,
            "PARALLEL_GROUP_SIZE": -1,
        },
    }
    parameters = Dict(dict={
        'GLOBAL': {
            'RUN_TYPE': 'ENERGY',
        },
        'FORCE_EVAL': {
            'METHOD': 'Quickstep',
            'DFT': dft_section,
        },
    })

    resources = {"num_machines": 1, "num_mpiprocs_per_machine": 1}
    options = {
        "resources": resources,
        "max_wallclock_seconds": 3 * 60,
    }

    inputs = {
        "structure": structure,
        "parameters": parameters,
        "code": cp2k_code,
        "metadata": {
            "options": options,
        },
        "basissets": cp2k_basissets,
        "pseudos": cp2k_pseudos,
    }

    _, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)

    # check that the ADDED_MOS keyword was added within the calculation
    with calc_node.open("aiida.inp") as fhandle:
        keyword_written = any("ADDED_MOS" in line for line in fhandle)
        assert keyword_written, "ADDED_MOS not found in the generated CP2K input file"

    assert calc_node.exit_status == 0

0 commit comments

Comments
 (0)