Skip to content

Commit cc5ea02

Browse files
committed
Feat: pop None inputs specified in overrides
Fixes #653 Add the possibility of popping input namespaces by specifying None in the override for the specific namespace. A decorator is added that generalize the concept to any implementation of get_builder_from_protocol.
1 parent ae7d248 commit cc5ea02

File tree

5 files changed

+94
-0
lines changed

5 files changed

+94
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Decorators for several purposes."""
2+
3+
4+
def remove_none_overrides(func):
5+
def recursively_remove_nones(value):
6+
"""Recursively remove keys with None values from dictionaries."""
7+
if isinstance(value, dict):
8+
return {k: recursively_remove_nones(v) for k, v in value.items() if v is not None}
9+
return value
10+
11+
def remove_keys_from_builder(builder, keys, path=()):
12+
"""Recursively remove specified keys from the builder based on a path."""
13+
if not keys:
14+
return
15+
current_level = keys.pop(0)
16+
if hasattr(builder, current_level):
17+
if keys:
18+
next_attr = getattr(builder, current_level)
19+
remove_keys_from_builder(next_attr, keys, path + (current_level,))
20+
else:
21+
delattr(builder, current_level)
22+
23+
def wrapper(*args, **kwargs):
24+
if 'overrides' in kwargs and kwargs['overrides'] is not None:
25+
original_overrides = kwargs['overrides']
26+
27+
# Identify paths to keys with None values to be removed
28+
paths_to_remove = []
29+
def find_paths(value, path=()):
30+
if isinstance(value, dict):
31+
for k, v in value.items():
32+
if v is None:
33+
paths_to_remove.append(path + (k,))
34+
else:
35+
find_paths(v, path + (k,))
36+
find_paths(original_overrides)
37+
38+
# Recursively remove keys with None values from overrides
39+
cleaned_overrides = recursively_remove_nones(original_overrides)
40+
kwargs['overrides'] = cleaned_overrides
41+
42+
# Call the original function to get the builder
43+
builder = func(*args, **kwargs)
44+
45+
# Remove specified keys from the builder
46+
for path in paths_to_remove:
47+
remove_keys_from_builder(builder, list(path))
48+
49+
return builder
50+
else:
51+
return func(*args, **kwargs)
52+
53+
return wrapper

src/aiida_quantumespresso/workflows/pw/bands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
99
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
1010
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
11+
from aiida_quantumespresso.utils.decorators import remove_none_overrides
1112

1213
from ..protocols.utils import ProtocolMixin
1314

@@ -120,6 +121,7 @@ def get_protocol_filepath(cls):
120121
return files(pw_protocols) / 'bands.yaml'
121122

122123
@classmethod
124+
@remove_none_overrides
123125
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs):
124126
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
125127

src/aiida_quantumespresso/workflows/pw/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
1010
from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
1111
from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults
12+
from aiida_quantumespresso.utils.decorators import remove_none_overrides
1213

1314
from ..protocols.utils import ProtocolMixin
1415

@@ -103,6 +104,7 @@ def get_protocol_filepath(cls):
103104
return files(pw_protocols) / 'base.yaml'
104105

105106
@classmethod
107+
@remove_none_overrides
106108
def get_builder_from_protocol(
107109
cls,
108110
code,

tests/workflows/protocols/pw/test_bands.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure):
103103
builder.bands.pw.metadata, # pylint: disable=no-member
104104
):
105105
assert subspace['options']['queue_name'] == queue_name, subspace
106+
107+
108+
def test_pop_none_overrides(fixture_code, generate_structure):
109+
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
110+
code = fixture_code('quantumespresso.pw')
111+
structure = generate_structure()
112+
113+
overrides = {'relax': {'base_final_scf':None}}
114+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
115+
116+
assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member
117+
118+
overrides = {'relax': None}
119+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
120+
121+
assert 'relax' not in builder # pylint: disable=no-member
122+
123+
overrides = {'relax': {'base':{'pw':{'parameters':{'SYSTEM':{'ecutwfc': None}}}}}}
124+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
125+
126+
assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member
127+
128+
overrides = {'relax': {'base':{'pw':{'parameters': None}}}}
129+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
130+
131+
assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member

tests/workflows/protocols/pw/test_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure):
241241

242242
assert metadata['options']['queue_name'] == queue_name
243243
assert metadata['options']['withmpi'] == withmpi
244+
245+
246+
def test_pop_none_overrides(fixture_code, generate_structure):
247+
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
248+
code = fixture_code('quantumespresso.pw')
249+
structure = generate_structure()
250+
251+
overrides = {'kpoints_distance': None}
252+
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
253+
254+
assert 'kpoints_distance' not in builder # pylint: disable=no-member

0 commit comments

Comments
 (0)