Skip to content

Commit f6623b0

Browse files
committed
Feat: pop None inputs specified in overrides
Fixes #653 Add the possibility of popping input namespaces by specifying None in the override for the specific namespace. A decorator is added that generalize the concept to any implementation of get_builder_from_protocol.
1 parent ae7d248 commit f6623b0

File tree

5 files changed

+101
-0
lines changed

5 files changed

+101
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# -*- coding: utf-8 -*-
2+
"""Decorators for several purposes."""
3+
4+
5+
def remove_none_overrides(func):
6+
"""Remove namespaces of the returned builder of a `get_builder*` method."""
7+
8+
def recursively_remove_nones(item):
9+
"""Recursively remove keys with None values from dictionaries."""
10+
if isinstance(item, dict):
11+
return {key: recursively_remove_nones(value) for key, value in item.items() if value is not None}
12+
return item
13+
14+
def remove_keys_from_builder(builder, keys, path=()):
15+
"""Recursively remove specified keys from the builder based on a path."""
16+
if not keys:
17+
return
18+
current_level = keys.pop(0)
19+
if hasattr(builder, current_level):
20+
if keys:
21+
next_attr = getattr(builder, current_level)
22+
remove_keys_from_builder(next_attr, keys, path + (current_level,))
23+
else:
24+
delattr(builder, current_level)
25+
26+
def wrapper(*args, **kwargs):
27+
"""Wrap the function."""
28+
if 'overrides' in kwargs and kwargs['overrides'] is not None:
29+
original_overrides = kwargs['overrides']
30+
31+
# Identify paths to keys with None values to be removed
32+
paths_to_remove = []
33+
34+
def find_paths(item, path=()):
35+
"""Find the paths to remove."""
36+
if isinstance(item, dict):
37+
for key, value in item.items():
38+
if value is None:
39+
paths_to_remove.append(path + (key,))
40+
else:
41+
find_paths(value, path + (key,))
42+
43+
find_paths(original_overrides)
44+
45+
# Recursively remove keys with None values from overrides
46+
cleaned_overrides = recursively_remove_nones(original_overrides)
47+
kwargs['overrides'] = cleaned_overrides
48+
49+
# Call the original function to get the builder
50+
builder = func(*args, **kwargs)
51+
52+
# Remove specified keys from the builder
53+
for path in paths_to_remove:
54+
remove_keys_from_builder(builder, list(path))
55+
56+
return builder
57+
58+
return func(*args, **kwargs)
59+
60+
return wrapper

src/aiida_quantumespresso/workflows/pw/bands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from aiida.engine import ToContext, WorkChain, if_
66

77
from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
8+
from aiida_quantumespresso.utils.decorators import remove_none_overrides
89
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
910
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
1011
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
@@ -120,6 +121,7 @@ def get_protocol_filepath(cls):
120121
return files(pw_protocols) / 'bands.yaml'
121122

122123
@classmethod
124+
@remove_none_overrides
123125
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs):
124126
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
125127

src/aiida_quantumespresso/workflows/pw/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
1010
from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
11+
from aiida_quantumespresso.utils.decorators import remove_none_overrides
1112
from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults
1213

1314
from ..protocols.utils import ProtocolMixin
@@ -103,6 +104,7 @@ def get_protocol_filepath(cls):
103104
return files(pw_protocols) / 'base.yaml'
104105

105106
@classmethod
107+
@remove_none_overrides
106108
def get_builder_from_protocol(
107109
cls,
108110
code,

tests/workflows/protocols/pw/test_bands.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure):
103103
builder.bands.pw.metadata, # pylint: disable=no-member
104104
):
105105
assert subspace['options']['queue_name'] == queue_name, subspace
106+
107+
108+
def test_pop_none_overrides(fixture_code, generate_structure):
109+
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
110+
code = fixture_code('quantumespresso.pw')
111+
structure = generate_structure()
112+
113+
overrides = {'relax': {'base_final_scf': None}}
114+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
115+
116+
assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member
117+
118+
overrides = {'relax': None}
119+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
120+
121+
assert 'relax' not in builder # pylint: disable=no-member
122+
123+
overrides = {'relax': {'base': {'pw': {'parameters': {'SYSTEM': {'ecutwfc': None}}}}}}
124+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
125+
126+
assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member
127+
128+
overrides = {'relax': {'base': {'pw': {'parameters': None}}}}
129+
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
130+
131+
assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member

tests/workflows/protocols/pw/test_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure):
241241

242242
assert metadata['options']['queue_name'] == queue_name
243243
assert metadata['options']['withmpi'] == withmpi
244+
245+
246+
def test_pop_none_overrides(fixture_code, generate_structure):
247+
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
248+
code = fixture_code('quantumespresso.pw')
249+
structure = generate_structure()
250+
251+
overrides = {'kpoints_distance': None}
252+
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
253+
254+
assert 'kpoints_distance' not in builder # pylint: disable=no-member

0 commit comments

Comments
 (0)