Skip to content

Commit 0b6d370

Browse files
committed
Implement multi-GPU training
Training is launched through torchrun in standalone mode (--standalone --nnodes=1), which limits us to single-node training
1 parent 4f21fc0 commit 0b6d370

File tree

4 files changed

+51
-48
lines changed

4 files changed

+51
-48
lines changed

psiflow/execution.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
)
247247

248248
if self.executor_type == "workqueue":
249-
# WQ-specific checks
249+
# WQ-specific checks TODO: what about multinode?
250250
ensure(
251251
self.kwargs["gpus_per_task"] <= resources["gpus"],
252252
self.kwargs["cores_per_task"] <= resources["cores"],
@@ -541,15 +541,10 @@ def __init__(self, **kwargs) -> None:
541541
"ModelTraining is configured for CPU operation. Is this what you want?"
542542
)
543543

544-
# if self.multigpu:
545-
# # TODO: why? Think this might be a multinode thing - which I do not care about
546-
# message = (
547-
# "the max_training_time keyword does not work "
548-
# "in combination with multi-gpu training. Adjust "
549-
# "the maximum number of epochs to control the "
550-
# "duration of training"
551-
# )
552-
# assert self.max_runtime is None, message
544+
@property
545+
def multi_gpu(self) -> bool:
546+
# only for WQ
547+
return (self.spec or {}).get("gpus", 0) > 1
553548

554549
def train_command(self, initialize: bool = False):
555550
command = "psiflow-mace-train"

psiflow/models/mace.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,6 @@ def format_E0s(atomic_energies: dict) -> dict | str:
7171
return "average"
7272

7373

74-
def get_bash_command(root: Path, pre: Optional[str] = None) -> str:
75-
defaults = [
76-
"mkdir checkpoints", # otherwise MACE borks
77-
"mace_run_train --config {}", # run MACE
78-
f"rsync -av --ignore-existing --exclude=/*.model ./ {root}/", # copy things back
79-
]
80-
if pre:
81-
defaults.insert(1, pre)
82-
return "\n".join(defaults)
83-
84-
8574
def _execute(
8675
bash_template: str,
8776
inputs: list[File],
@@ -176,19 +165,49 @@ def _resolve_futures(self) -> AppFuture:
176165
cfg = self.config | {KEY_ATOMIC_ENERGIES: self.atomic_energies}
177166
return resolve_nested_futures(cfg)
178167

179-
@staticmethod
180-
def get_app(init: bool = False) -> tuple[Callable, str]:
181-
# TODO: update -- why staticmethod and not function?
168+
def _execute_app(self, config: dict) -> AppFuture:
169+
""""""
182170
context = psiflow.context()
183171
definition = context.definitions["ModelTraining"]
184-
template = context.bash_template
172+
resources = definition.wq_resources()
173+
174+
# final config tweaks
175+
if definition.multi_gpu:
176+
config["distributed"] = True
177+
config["launcher"] = "torchrun"
178+
else:
179+
config["distributed"] = False
180+
file = psiflow.context().new_file("mace_cfg_", ".yaml")
181+
yaml.safe_dump(config, open(file.filepath, "w"))
182+
183+
# construct MACE train script
184+
command = "$(which mace_run_train) --config {}"
185+
if config["distributed"]:
186+
command = f"torchrun --standalone --nnodes=1 --nproc_per_node={resources['gpus']} {command}"
187+
command = definition.wrap_in_timeout(command)
188+
189+
pre_copy = ""
190+
if self.has_checkpoint: # restart from latest checkpoint
191+
chkpt_out = f"checkpoints/{create_checkpoint_name(config)}"
192+
pre_copy = f"cp {self.path_chkpt} {chkpt_out}"
193+
194+
command_lines = [
195+
"mkdir checkpoints", # otherwise MACE borks
196+
pre_copy,
197+
command,
198+
f"rsync -av --ignore-existing --exclude=/*.model ./ {self.root}/", # copy things back
199+
]
200+
command = "\n".join([l for l in command_lines if l])
201+
202+
execute_app = bash_app(_execute, executors=["ModelTraining"])
185203
env_vars = format_env_vars(definition.env_vars)
186-
app = partial(
187-
bash_app(_execute, executors=["ModelTraining"]),
188-
parsl_resource_specification=definition.wq_resources(),
189-
label="mace_init" if init else "mace_train",
204+
future = execute_app(
205+
bash_template=context.bash_template.format(commands=command, env=env_vars),
206+
inputs=[file],
207+
parsl_resource_specification=resources,
208+
label="mace_init" if config["name"] == "init" else "mace_train",
190209
)
191-
return app, template.format(commands="{commands}", env=env_vars)
210+
return future
192211

193212
@property
194213
def path_config(self) -> Path:
@@ -256,13 +275,8 @@ def initialize_app(
256275
"valid_file": file_val.filepath,
257276
}
258277
cfg["E0s"] = format_E0s(model.atomic_energies)
259-
file_cfg = psiflow.context().new_file("mace_cfg_", ".yaml")
260-
yaml.safe_dump(cfg, open(file_cfg.filepath, "w"))
261-
262-
command = get_bash_command(model.root)
263-
app, template = model.get_app(init=True)
264278

265-
future = app(template.format(commands=command), inputs=[file_cfg])
279+
future = model._execute_app(cfg)
266280
inputs = [future.stdout, future.stderr, future]
267281
future_ = process_output(model, config, inputs=inputs)
268282
return future_
@@ -292,17 +306,8 @@ def train_app(
292306
"valid_file": file_val.filepath,
293307
}
294308
cfg["E0s"] = format_E0s(cfg.pop(KEY_ATOMIC_ENERGIES))
295-
file_cfg = psiflow.context().new_file("mace_cfg_", ".yaml")
296-
yaml.safe_dump(cfg, open(file_cfg.filepath, "w"))
297-
298-
pre = None
299-
if model.has_checkpoint:
300-
chkpt_out = f"checkpoints/{create_checkpoint_name(cfg)}"
301-
pre = f"cp {model.path_chkpt} {chkpt_out}"
302-
command = get_bash_command(model.root, pre=pre)
303-
app, template = model.get_app(init=False)
304309

305-
future = app(template.format(commands=command), inputs=[file_cfg])
310+
future = model._execute_app(cfg)
306311
inputs = [future.stdout, future.stderr, future]
307312
future_ = process_output(model, config, inputs=inputs)
308313
return future_

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
55

66
[project]
77
name = "psiflow"
8-
version = "4.0.3"
8+
version = "4.0.4"
99
description = "Library for developing interatomic potentials"
1010
readme = "README.md"
1111
requires-python = ">=3.10"

tests/test_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,13 @@ def test_mace_hamiltonian(dataset, mace_foundation):
112112
assert hamiltonian0 != hamiltonian1
113113
hamiltonian2 = psiflow.deserialize(psiflow.serialize(hamiltonian1)).result()
114114
assert hamiltonian0 != hamiltonian2
115+
assert hamiltonian1 == hamiltonian2
116+
hamiltonian2.update_kwargs(enable_cueq=True)
115117

116118
e0 = hamiltonian0.compute(dataset, "energy")
117119
e1 = hamiltonian1.compute(dataset, "energy")
120+
e2 = hamiltonian2.compute(dataset, "energy")
118121
assert np.allclose(e0.result(), e1.result())
119-
pass
122+
assert np.allclose(e0.result(), e2.result())
120123

121124

0 commit comments

Comments
 (0)