Skip to content

Commit 4ef0a23

Browse files
authored
Merge pull request #623 from jmmshn/dynamic_poweups
better handling of dynamic powerups
2 parents 68b1c01 + 8ea2b0f commit 4ef0a23

File tree

7 files changed

+125
-58
lines changed

7 files changed

+125
-58
lines changed

atomate/common/powerups.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
"""
44
This module defines general powerups that can be used for all workflows
55
"""
6+
from importlib import import_module
7+
from typing import List
68

79
from atomate.utils.utils import get_fws_and_tasks
810
from fireworks import Workflow, FileWriteTask
911
from fireworks.utilities.fw_utilities import get_slug
1012

11-
1213
__author__ = "Janine George, Guido Petretto, Ryan Kingsbury"
1314
__email__ = (
1415
@@ -219,3 +220,60 @@ def set_queue_adapter(
219220
original_wf.fws[idx_fw].spec["_queueadapter"] = q
220221

221222
return original_wf
223+
224+
225+
def powerup_by_kwargs(
226+
original_wf: Workflow,
227+
powerup_dicts: List[dict],
228+
):
229+
"""
230+
apply powerups in the form using a list of dictionaries
231+
[
232+
{"powerup_name" : powerup_function1, "kwargs": {parameter1 : value1, parameter2: value2}},
233+
{"powerup_name" : powerup_function2, "kwargs": {parameter1 : value1, parameter2: value2}},
234+
]
235+
236+
As an example:
237+
power_up_by_kwargs([
238+
{"powerup_name" : "add_additional_fields_to_taskdocs",
239+
"kwargs: {"update_dict" : {"foo" : "bar"}}}
240+
]
241+
)
242+
243+
Args:
244+
original_wf: workflow that will be changed
245+
powerup_dicts: dictionary containing the powerup_name and kwarg.
246+
if "." is present in the name it will be imported as a full path
247+
if not we will use standard atomate modules where the powerups are kept
248+
249+
"""
250+
# a list of possible powerups in atomate (most specific first)
251+
powerup_modules = [
252+
"atomate.vasp.powerups",
253+
"atomate.qchem.powerups",
254+
"atomate.common.powerups",
255+
]
256+
257+
for pd in powerup_dicts:
258+
name = pd["powerup_name"]
259+
kwargs = pd["kwargs"]
260+
found = False
261+
if "." in name:
262+
module_name, method_name = name.rsplit(".", 1)
263+
module = import_module(module_name)
264+
powerup = getattr(module, method_name)
265+
original_wf = powerup(original_wf, **kwargs)
266+
found = True
267+
else:
268+
for module_name in powerup_modules:
269+
try:
270+
module = import_module(module_name)
271+
powerup = getattr(module, name)
272+
original_wf = powerup(original_wf, **kwargs)
273+
found = True
274+
break
275+
except Exception:
276+
pass
277+
if not found:
278+
raise RuntimeError("Could not find powerup {}.".format(name))
279+
return original_wf

atomate/common/tests/test_powerups.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
from pymatgen.io.vasp.sets import MPRelaxSet
66
from pymatgen.util.testing import PymatgenTest
77

8-
from atomate.common.powerups import set_queue_adapter, add_priority, add_tags
8+
from atomate.common.powerups import (
9+
set_queue_adapter,
10+
add_priority,
11+
add_tags,
12+
powerup_by_kwargs,
13+
)
914
from fireworks import Firework, ScriptTask, Workflow
1015

1116
__author__ = "Janine George, Guido Petretto"
@@ -76,6 +81,20 @@ def test_add_tags(self):
7681
self.assertEqual(b_found, 1)
7782
self.assertEqual(v_found, 4)
7883

84+
def test_powerup_by_kwargs(self):
85+
my_wf = copy_wf(self.bs_wf)
86+
my_wf = powerup_by_kwargs(
87+
my_wf,
88+
[
89+
{"powerup_name": "add_tags", "kwargs": {"tags_list": ["foo", "bar"]}},
90+
{
91+
"powerup_name": "atomate.common.powerups.add_priority",
92+
"kwargs": {"root_priority": 123},
93+
},
94+
],
95+
)
96+
self.assertEqual(my_wf.metadata["tags"], ["foo", "bar"])
97+
7998
def test_set_queue_adapter(self):
8099
# test fw_name_constraint
81100
fw1 = Firework([ScriptTask(script=None)], fw_id=-1, name="Firsttask")

atomate/vasp/firetasks/electrode_tasks.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
from collections import defaultdict
32

43
from fireworks import FiretaskBase, explicit_serialize, FWAction, Firework, Workflow
54
from pymatgen.core import Structure
@@ -13,7 +12,7 @@
1312

1413
from atomate.vasp.database import VaspCalcDb
1514
from atomate.vasp.firetasks import pass_vasp_result
16-
from atomate.vasp.powerups import powerup_by_kwargs, POWERUP_NAMES
15+
from atomate.common.powerups import powerup_by_kwargs
1716

1817
__author__ = "Jimmy Shen"
1918
__email__ = "[email protected]"
@@ -316,11 +315,13 @@ def get_powerup_wf(wf, fw_spec, additional_fields=None):
316315
Returns:
317316
Updated workflow
318317
"""
319-
d_pu = defaultdict(dict)
320-
d_pu.update(fw_spec.get("vasp_powerups", {}))
318+
powerup_list = []
319+
powerup_list.extend(fw_spec.get("vasp_powerups", []))
321320
if additional_fields is not None:
322-
d_pu["add_additional_fields_to_taskdocs"].update(
323-
{"update_dict": additional_fields}
321+
powerup_list.append(
322+
{
323+
"powerup_name": "add_additional_fields_to_taskdocs",
324+
"kwargs": {"update_dict": additional_fields},
325+
}
324326
)
325-
p_kwargs = {k: d_pu[k] for k in POWERUP_NAMES if k in d_pu}
326-
return powerup_by_kwargs(wf, **p_kwargs)
327+
return powerup_by_kwargs(wf, powerup_list)

atomate/vasp/powerups.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
__author__ = "Anubhav Jain, Kiran Mathew, Alex Ganose"
3434
3535

36-
POWERUP_NAMES = []
37-
3836

3937
@deprecated(replacement=common_add_priority)
4038
def add_priority(original_wf, root_priority, child_priority=None):
@@ -829,35 +827,3 @@ def use_fake_lobster(original_wf, ref_dirs, params_to_check=None):
829827
)
830828

831829
return original_wf
832-
833-
834-
local_names = dict(locals()) # locals() changes as the program runs, make a copy here
835-
836-
for k, v in local_names.items():
837-
if (
838-
hasattr(v, "__module__")
839-
and v.__module__ == "atomate.vasp.powerups"
840-
and k != "power_up_by_kwargs"
841-
):
842-
POWERUP_NAMES.append(k)
843-
844-
local_names = {k: v for k, v in local_names.items() if k in POWERUP_NAMES}
845-
846-
847-
def powerup_by_kwargs(wf, **kwargs):
848-
"""
849-
apply powerups in the form using a kwargs dictionary of the form:
850-
{
851-
powerup_function_name1 : {parameter1 : value1, parameter2: value2},
852-
powerup_function_name2 : {parameter1 : value1, parameter2: value2},
853-
}
854-
855-
As an example:
856-
power_up_by_kwargs( "add_additional_fields_to_taskdocs" : {
857-
"update_dict" : {"foo" : "bar"}
858-
}
859-
)
860-
"""
861-
for k, v in kwargs.items():
862-
wf = local_names[k](wf, **v)
863-
return wf

atomate/vasp/tests/test_vasp_powerups.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
set_queue_options,
2020
use_potcar_spec,
2121
)
22+
from atomate.common.powerups import powerup_by_kwargs
2223
from atomate.vasp.workflows.base.core import get_wf
2324

2425
from pymatgen.io.vasp.sets import MPRelaxSet
@@ -244,6 +245,19 @@ def test_use_potcar_spec(self):
244245
task = wf.fws[idx_fw].tasks[idx_t]
245246
self.assertTrue(task["potcar_spec"])
246247

248+
def test_powerup_by_kwargs(self):
249+
my_wf = copy_wf(self.bs_wf)
250+
my_wf = powerup_by_kwargs(
251+
my_wf, [{"powerup_name": "add_trackers", "kwargs": {}}]
252+
)
253+
my_wf = powerup_by_kwargs(
254+
my_wf,
255+
[{"powerup_name": "add_tags", "kwargs": {"tags_list": ["foo", "bar"]}}],
256+
)
257+
for fw in my_wf.fws:
258+
self.assertEqual(len(fw.spec["_trackers"]), 2)
259+
self.assertEqual(my_wf.metadata["tags"], ["foo", "bar"])
260+
247261

248262
def copy_wf(wf):
249263
return Workflow.from_dict(wf.to_dict())

atomate/vasp/workflows/base/electrode.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
from fireworks import Workflow
24
from pymatgen.core import Structure
35
from pymatgen.analysis.structure_matcher import StructureMatcher
@@ -9,7 +11,7 @@
911
__email__ = "[email protected]"
1012

1113
from atomate.vasp.fireworks import Firework, OptimizeFW, StaticFW, pass_vasp_result
12-
from atomate.vasp.powerups import powerup_by_kwargs
14+
from atomate.common.powerups import powerup_by_kwargs
1315

1416
"""
1517
Define workflow related to battery material simulation --- they all have a working ion
@@ -23,7 +25,7 @@ def get_ion_insertion_wf(
2325
db_file: str = DB_FILE,
2426
vasptodb_kwargs: dict = None,
2527
volumetric_data_type: str = "CHGCAR",
26-
vasp_powerups: dict = None,
28+
vasp_powerups: List[dict] = None,
2729
max_insertions: int = 4,
2830
allow_fizzled_parents: bool = True,
2931
optimizefw_kwargs: dict = None,
@@ -106,7 +108,7 @@ def get_ion_insertion_wf(
106108

107109
# Apply the vasp powerup if present
108110
if vasp_powerups is not None:
109-
wf = powerup_by_kwargs(wf, **vasp_powerups)
111+
wf = powerup_by_kwargs(wf, vasp_powerups)
110112
for fw in wf.fws:
111113
fw.spec["vasp_powerups"] = vasp_powerups
112114

atomate/vasp/workflows/tests/test_insertion_workflow.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
wf_dir = ref_dir / "insertion_wf"
2222

2323
VASP_CMD = None # for fake VASP
24+
DEBUG_MODE = (
25+
False # If true, retains the database and output dirs at the end of the test
26+
)
2427

2528

2629
class TestInsertionWorkflow(AtomateTest):
@@ -38,19 +41,23 @@ def setUp(self):
3841
working_ion="Mg",
3942
volumetric_data_type="AECCAR",
4043
db_file=db_dir / "db.json",
41-
vasp_powerups={
42-
"add_modify_incar": {
43-
"modify_incar_params": {"incar_update": {"KPAR": 8}}
44+
vasp_powerups=[
45+
{
46+
"powerup_name": "add_modify_incar",
47+
"kwargs": {"modify_incar_params": {"incar_update": {"KPAR": 8}}},
4448
},
45-
"use_fake_vasp": {
46-
"ref_dirs": calc_dirs,
47-
"check_incar": False,
48-
"check_kpoints": False,
49-
"check_poscar": False,
50-
"check_potcar": False,
49+
{
50+
"powerup_name": "use_fake_vasp",
51+
"kwargs": {
52+
"ref_dirs": calc_dirs,
53+
"check_incar": False,
54+
"check_kpoints": False,
55+
"check_poscar": False,
56+
"check_potcar": False,
57+
},
5158
},
52-
"use_potcar_spec": {},
53-
},
59+
{"powerup_name": "use_potcar_spec", "kwargs": {}},
60+
],
5461
optimizefw_kwargs={"ediffg": -0.05},
5562
)
5663
wf = use_fake_vasp(

0 commit comments

Comments
 (0)