|
| 1 | +from fireworks import FiretaskBase, FWAction, explicit_serialize, Workflow |
| 2 | +from atomate.utils.utils import env_chk |
| 3 | +from atomate.vasp.database import VaspCalcDb |
| 4 | +from atomate.vasp.fireworks.approx_neb import ImageFW |
| 5 | +from atomate.common.powerups import powerup_by_kwargs |
| 6 | + |
| 7 | +__author__ = "Ann Rutt" |
| 8 | + |
| 9 | + |
| 10 | + |
| 11 | +@explicit_serialize |
| 12 | +class GetImageFireworks(FiretaskBase): |
| 13 | + """ |
| 14 | + Adds ImageFWs to the workflow for the provided images_key |
| 15 | + according to the scheme specified by launch_mode. Optional |
| 16 | + parameters such as "handler_group", "add_additional_fields", |
| 17 | + and "add_tags" can be used to modify the resulting ImageFWs. |
| 18 | +
|
| 19 | + Args: |
| 20 | + db_file (str): path to file containing the database |
| 21 | + credentials. |
| 22 | + approx_neb_wf_uuid (str): unique id for approx neb workflow |
| 23 | + record keeping. |
| 24 | + images_key (str): specifies a key corresponding the images |
| 25 | + field of the approx_neb collection which specifies the |
| 26 | + desired combination of end points to interpolate images |
| 27 | + between. images_key should be a string of format "0+1", |
| 28 | + "0+2", etc. matching end_points_combo input of |
| 29 | + PathfinderToDb Firetask or pathfinder_key input of |
| 30 | + AddSelectiveDynamics Firetask. If images_key is not |
| 31 | + provided images will be launched for all paths/keys in |
| 32 | + the approx_neb collection images field. |
| 33 | + launch_mode (str): "all" or "screening" |
| 34 | + vasp_cmd (str): the name of the full executable for running |
| 35 | + VASP. |
| 36 | + Optional Params: |
| 37 | + vasp_input_set (VaspInputSet class): can use to |
| 38 | + define VASP input parameters. |
| 39 | + See pymatgen.io.vasp.sets module for more |
| 40 | + information. MPRelaxSet() and |
| 41 | + override_default_vasp_params are used if |
| 42 | + vasp_input_set = None. |
| 43 | + override_default_vasp_params (dict): if provided, |
| 44 | + vasp_input_set is disregarded and the Vasp Input |
| 45 | + Set is created by passing |
| 46 | + override_default_vasp_params to MPRelaxSet(). |
| 47 | + Allows for easy modification of MPRelaxSet(). |
| 48 | + For example, to set ISIF=2 in the INCAR use: |
| 49 | + {"user_incar_settings":{"ISIF":2}} |
| 50 | + handler_group (str or [ErrorHandler]): group of handlers to |
| 51 | + use for RunVaspCustodian firetask. See handler_groups |
| 52 | + dict in the code for the groups and complete list of |
| 53 | + handlers in each group. Alternatively, you can specify a |
| 54 | + list of ErrorHandler objects. |
| 55 | + add_additional_fields (dict): dict of additional fields to |
| 56 | + add to task docs (by additional_fields of VaspToDb). |
| 57 | + add_tags (list of strings): added to the "tags" field of the |
| 58 | + task docs. |
| 59 | + """ |
| 60 | + |
| 61 | + required_params = [ |
| 62 | + "db_file", |
| 63 | + "approx_neb_wf_uuid", |
| 64 | + "images_key", |
| 65 | + "launch_mode", |
| 66 | + "vasp_cmd", |
| 67 | + ] |
| 68 | + optional_params = [ |
| 69 | + "vasp_input_set", |
| 70 | + "override_default_vasp_params", |
| 71 | + "handler_group", |
| 72 | + "add_additional_fields", |
| 73 | + "add_tags", |
| 74 | + ] |
| 75 | + |
| 76 | + def run_task(self, fw_spec): |
| 77 | + # get the database connection |
| 78 | + db_file = env_chk(self["db_file"], fw_spec) |
| 79 | + mmdb = VaspCalcDb.from_db_file(db_file, admin=True) |
| 80 | + mmdb.collection = mmdb.db["approx_neb"] |
| 81 | + wf_uuid = self["approx_neb_wf_uuid"] |
| 82 | + launch_mode = self["launch_mode"] |
| 83 | + images_key = self["images_key"] |
| 84 | + |
| 85 | + approx_neb_doc = mmdb.collection.find_one({"wf_uuid": wf_uuid}, {"images": 1}) |
| 86 | + all_images = approx_neb_doc["images"] |
| 87 | + |
| 88 | + # get structure_path of desired images and sort into structure_paths |
| 89 | + if images_key and isinstance(all_images, (dict)): |
| 90 | + images = all_images[images_key] |
| 91 | + max_n = len(images) |
| 92 | + if launch_mode == "all": |
| 93 | + structure_paths = [ |
| 94 | + "images." + images_key + "." + str(n) + ".input_structure" |
| 95 | + for n in range(0, max_n) |
| 96 | + ] |
| 97 | + elif launch_mode == "screening": |
| 98 | + structure_paths = self.get_and_sort_paths( |
| 99 | + max_n=max_n, images_key=images_key |
| 100 | + ) |
| 101 | + elif isinstance(all_images, (dict)): |
| 102 | + structure_paths = dict() |
| 103 | + if launch_mode == "all": |
| 104 | + for key, images in all_images.items(): |
| 105 | + max_n = len(images) |
| 106 | + structure_paths[key] = [ |
| 107 | + "images." + key + "." + str(n) + ".input_structure" |
| 108 | + for n in range(0, max_n) |
| 109 | + ] |
| 110 | + elif launch_mode == "screening": |
| 111 | + for key, images in all_images.items(): |
| 112 | + structure_paths[key] = self.get_and_sort_paths( |
| 113 | + max_n=len(images), images_key=key |
| 114 | + ) |
| 115 | + |
| 116 | + # get list of fireworks to launch |
| 117 | + if isinstance(structure_paths, (list)): |
| 118 | + if isinstance(structure_paths[0], (str)): |
| 119 | + relax_image_fws = [] |
| 120 | + for path in structure_paths: |
| 121 | + relax_image_fws.append(self.get_fw(structure_path=path)) |
| 122 | + else: |
| 123 | + relax_image_fws = self.get_screening_fws(sorted_paths=structure_paths) |
| 124 | + elif isinstance(structure_paths, (dict)): |
| 125 | + relax_image_fws = [] |
| 126 | + if launch_mode == "all": |
| 127 | + for key in structure_paths.keys(): |
| 128 | + for path in structure_paths[key]: |
| 129 | + relax_image_fws.append(self.get_fw(structure_path=path)) |
| 130 | + elif launch_mode == "screening": |
| 131 | + for key in structure_paths.keys(): |
| 132 | + sorted_paths = structure_paths[key] |
| 133 | + relax_image_fws.extend( |
| 134 | + self.get_screening_fws(sorted_paths=sorted_paths) |
| 135 | + ) |
| 136 | + |
| 137 | + # place fws in temporary wf in order to use powerup_by_kwargs |
| 138 | + # to apply powerups to image fireworks |
| 139 | + if "vasp_powerups" in fw_spec.keys(): |
| 140 | + temp_wf = Workflow(relax_image_fws) |
| 141 | + powerup_dicts = fw_spec["vasp_powerups"] |
| 142 | + temp_wf = powerup_by_kwargs(temp_wf, powerup_dicts) |
| 143 | + relax_image_fws = temp_wf.fws |
| 144 | + |
| 145 | + return FWAction(additions=relax_image_fws) |
| 146 | + |
| 147 | + def get_and_sort_paths(self, max_n, images_key=""): |
| 148 | + sorted_paths = [[], [], []] |
| 149 | + mid_n = int(max_n / 2) |
| 150 | + q1 = int((max_n - mid_n) / 2) # for second screening pass |
| 151 | + q3 = int((max_n + mid_n) / 2) # for second screening pass |
| 152 | + |
| 153 | + for n in range(0, max_n): |
| 154 | + path = "images." + images_key + "." + str(n) + ".input_structure" |
| 155 | + if n == mid_n: # path for first screening pass (center image index) |
| 156 | + sorted_paths[0].append(path) |
| 157 | + elif n in [q1, q3]: |
| 158 | + sorted_paths[1].append(path) |
| 159 | + else: |
| 160 | + sorted_paths[-1].append(path) |
| 161 | + |
| 162 | + return sorted_paths |
| 163 | + |
| 164 | + def get_fw(self, structure_path, parents=None): |
| 165 | + add_tags = self.get("add_tags") |
| 166 | + fw = ImageFW( |
| 167 | + approx_neb_wf_uuid=self["approx_neb_wf_uuid"], |
| 168 | + structure_path=structure_path, |
| 169 | + db_file=self["db_file"], |
| 170 | + vasp_input_set=self.get("vasp_input_set"), |
| 171 | + vasp_cmd=self["vasp_cmd"], |
| 172 | + override_default_vasp_params=self.get("override_default_vasp_params"), |
| 173 | + handler_group=self.get("handler_group"), |
| 174 | + parents=parents, |
| 175 | + add_additional_fields=self.get("add_additional_fields"), |
| 176 | + add_tags=add_tags, |
| 177 | + ) |
| 178 | + if isinstance(add_tags, (list)): |
| 179 | + if "tags" in fw.spec.keys(): |
| 180 | + fw.spec["tags"].extend(add_tags) |
| 181 | + else: |
| 182 | + fw.spec["tags"] = add_tags |
| 183 | + return fw |
| 184 | + |
| 185 | + def get_screening_fws(self, sorted_paths): |
| 186 | + if isinstance(sorted_paths, (list)) != True: |
| 187 | + if ( |
| 188 | + any([isinstance(i, (list)) for i in sorted_paths]) != True |
| 189 | + or len(sorted_paths) != 3 |
| 190 | + ): |
| 191 | + raise TypeError("sorted_paths must be a list containing 3 lists") |
| 192 | + |
| 193 | + s1_fw = self.get_fw(structure_path=sorted_paths[0][0]) |
| 194 | + # ToDo: modify this firework to add firetask that checks whether to run/defuse children |
| 195 | + |
| 196 | + s2_fws = [] |
| 197 | + for path in sorted_paths[1]: |
| 198 | + s2_fws.append(self.get_fw(structure_path=path, parents=s1_fw)) |
| 199 | + # ToDo: modify this firework to add firetask that checks whether to run/defuse children |
| 200 | + |
| 201 | + remaining_fws = [] |
| 202 | + for path in sorted_paths[-1]: |
| 203 | + remaining_fws.append(self.get_fw(structure_path=path, parents=s2_fws)) |
| 204 | + |
| 205 | + return [s1_fw] + s2_fws + remaining_fws |
0 commit comments