@@ -71,17 +71,6 @@ def format_E0s(atomic_energies: dict) -> dict | str:
7171 return "average"
7272
7373
def get_bash_command(root: Path, pre: Optional[str] = None) -> str:
    """Build the newline-separated shell script for a MACE training run.

    Args:
        root: directory the working directory is synced back to once
            training has finished.
        pre: optional extra shell command, placed right after the ``mkdir``
            step and before training. Falsy values (``None``, ``""``) add
            nothing.

    Returns:
        The commands joined by newlines. The ``{}`` following ``--config``
        is a plain (non f-string) literal and is returned unchanged.
    """
    steps = ["mkdir checkpoints"]  # otherwise MACE borks
    if pre:
        steps.append(pre)
    steps += [
        "mace_run_train --config {}",  # run MACE
        # copy things back; the destination must be exactly "<root>/" --
        # a space before the slash would make rsync write to "/" instead
        f"rsync -av --ignore-existing --exclude=/*.model ./ {root}/",
    ]
    return "\n".join(steps)
83-
84-
8574def _execute (
8675 bash_template : str ,
8776 inputs : list [File ],
@@ -176,19 +165,49 @@ def _resolve_futures(self) -> AppFuture:
176165 cfg = self .config | {KEY_ATOMIC_ENERGIES : self .atomic_energies }
177166 return resolve_nested_futures (cfg )
178167
179- @staticmethod
180- def get_app (init : bool = False ) -> tuple [Callable , str ]:
181- # TODO: update -- why staticmethod and not function?
def _execute_app(self, config: dict) -> AppFuture:
    """Write the MACE config to disk and submit the training job.

    ``config`` is modified in place: ``distributed`` (and, for multi-GPU
    definitions, ``launcher``) are set before the dict is dumped to a newly
    created YAML file, which is handed to the bash app as its only input.

    Args:
        config: MACE training configuration. Must contain ``"name"``; the
            value ``"init"`` selects the ``mace_init`` label, anything else
            ``mace_train``.

    Returns:
        The future returned by the bash app run on the ``ModelTraining``
        executor.
    """
    context = psiflow.context()
    definition = context.definitions["ModelTraining"]
    resources = definition.wq_resources()

    # final config tweaks
    if definition.multi_gpu:
        config["distributed"] = True
        config["launcher"] = "torchrun"
    else:
        config["distributed"] = False
    file = psiflow.context().new_file("mace_cfg_", ".yaml")
    # close (and thereby flush) the handle before the job that reads the
    # config is submitted; a bare open() leaves that to the garbage collector
    with open(file.filepath, "w") as handle:
        yaml.safe_dump(config, handle)

    # construct MACE train script
    command = "$(which mace_run_train) --config {}"
    if config["distributed"]:
        command = f"torchrun --standalone --nnodes=1 --nproc_per_node={resources['gpus']} {command}"
    command = definition.wrap_in_timeout(command)

    pre_copy = ""
    if self.has_checkpoint:  # restart from latest checkpoint
        chkpt_out = f"checkpoints/{create_checkpoint_name(config)}"
        pre_copy = f"cp {self.path_chkpt} {chkpt_out}"

    command_lines = [
        "mkdir checkpoints",  # otherwise MACE borks
        pre_copy,
        command,
        # copy things back; the destination must be exactly "<root>/"
        f"rsync -av --ignore-existing --exclude=/*.model ./ {self.root}/",
    ]
    # empty entries (e.g. no checkpoint to restore) are dropped
    command = "\n".join(line for line in command_lines if line)

    execute_app = bash_app(_execute, executors=["ModelTraining"])
    env_vars = format_env_vars(definition.env_vars)
    future = execute_app(
        bash_template=context.bash_template.format(commands=command, env=env_vars),
        inputs=[file],
        parsl_resource_specification=resources,
        label="mace_init" if config["name"] == "init" else "mace_train",
    )
    return future
192211
193212 @property
194213 def path_config (self ) -> Path :
@@ -256,13 +275,8 @@ def initialize_app(
256275 "valid_file" : file_val .filepath ,
257276 }
258277 cfg ["E0s" ] = format_E0s (model .atomic_energies )
259- file_cfg = psiflow .context ().new_file ("mace_cfg_" , ".yaml" )
260- yaml .safe_dump (cfg , open (file_cfg .filepath , "w" ))
261-
262- command = get_bash_command (model .root )
263- app , template = model .get_app (init = True )
264278
265- future = app ( template . format ( commands = command ), inputs = [ file_cfg ] )
279+ future = model . _execute_app ( cfg )
266280 inputs = [future .stdout , future .stderr , future ]
267281 future_ = process_output (model , config , inputs = inputs )
268282 return future_
@@ -292,17 +306,8 @@ def train_app(
292306 "valid_file" : file_val .filepath ,
293307 }
294308 cfg ["E0s" ] = format_E0s (cfg .pop (KEY_ATOMIC_ENERGIES ))
295- file_cfg = psiflow .context ().new_file ("mace_cfg_" , ".yaml" )
296- yaml .safe_dump (cfg , open (file_cfg .filepath , "w" ))
297-
298- pre = None
299- if model .has_checkpoint :
300- chkpt_out = f"checkpoints/{ create_checkpoint_name (cfg )} "
301- pre = f"cp { model .path_chkpt } { chkpt_out } "
302- command = get_bash_command (model .root , pre = pre )
303- app , template = model .get_app (init = False )
304309
305- future = app ( template . format ( commands = command ), inputs = [ file_cfg ] )
310+ future = model . _execute_app ( cfg )
306311 inputs = [future .stdout , future .stderr , future ]
307312 future_ = process_output (model , config , inputs = inputs )
308313 return future_
0 commit comments