Commit 35783bb

update auto ddp code
1 parent 78fe1b8 commit 35783bb

2 files changed: 85 additions & 27 deletions

deepmd/pd/model/model/make_model.py

Lines changed: 45 additions & 11 deletions
@@ -37,6 +37,9 @@
 from deepmd.utils.path import (
     DPPath,
 )
+import paddle.distributed as dist
+from paddle.distributed import fleet
+import functools
 
 
 def make_model(T_AtomicModel: type[BaseAtomicModel]):
@@ -163,29 +166,60 @@ def forward_common(
                 coord, box=box, fparam=fparam, aparam=aparam
             )
             del coord, box, fparam, aparam
+            # (
+            #     extended_coord,
+            #     extended_atype,
+            #     mapping,
+            #     nlist,
+            # ) = extend_input_and_build_neighbor_list(
+            #     cc,
+            #     atype,
+            #     self.get_rcut(),
+            #     self.get_sel(),
+            #     # types will be distinguished in the lower interface,
+            #     # so it doesn't need to be distinguished here
+            #     mixed_types=True,
+            #     box=bb,
+            # )
+            wrapped_func_1 = dist.local_map(
+                func=lambda a,b,c: extend_input_and_build_neighbor_list(a,b,self.get_rcut(), self.get_sel(), True, c),
+                in_placements=[ele.placements for ele in [cc, atype, bb]],
+                out_placements=[[dist.Shard(0)] for _ in range(4)],
+                process_mesh=fleet.auto.get_mesh()
+            )
+
             (
                 extended_coord,
                 extended_atype,
                 mapping,
                 nlist,
-            ) = extend_input_and_build_neighbor_list(
+            ) = wrapped_func_1(
                 cc,
                 atype,
-                self.get_rcut(),
-                self.get_sel(),
-                # types will be distinguished in the lower interface,
-                # so it doesn't need to be distinguished here
-                mixed_types=True,
-                box=bb,
+                bb,
+            )
+            # model_predict_lower = self.forward_common_lower(
+            #     extended_coord,
+            #     extended_atype,
+            #     nlist,
+            #     mapping,
+            #     do_atomic_virial=do_atomic_virial,
+            #     fparam=fp,
+            #     aparam=ap,
+            # )
+
+            wrapped_func_2 = dist.local_map(
+                func=functools.partial(self.forward_common_lower, do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap),
+                in_placements=[ele.placements for ele in [extended_coord, extended_atype, nlist, mapping]],
+                out_placements=[[dist.Shard(0)] for _ in range(6)],
+                process_mesh=fleet.auto.get_mesh(),
+                reshard_inputs=True
             )
-            model_predict_lower = self.forward_common_lower(
+            model_predict_lower = wrapped_func_2(
                 extended_coord,
                 extended_atype,
                 nlist,
                 mapping,
-                do_atomic_virial=do_atomic_virial,
-                fparam=fp,
-                aparam=ap,
             )
             model_predict = communicate_extended_output(
                 model_predict_lower,

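Note on the pattern above: dist.local_map wraps a function that only understands plain local tensors so it can be called on auto-parallel distributed tensors, with in_placements/out_placements describing how inputs and outputs are sharded over the process mesh. The snippet below is a minimal, self-contained sketch of that pattern, not part of this commit; the 2-way "dp" mesh, the toy per-rank function toy_build, and the tensor shapes are illustrative assumptions.

# Minimal sketch of the dist.local_map pattern (assumes a recent Paddle with
# auto-parallel; launch with `python -m paddle.distributed.launch --nproc_per_node=2`).
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

fleet.auto.create_mesh([("dp", 2)])  # illustrative 2-way data-parallel mesh
fleet.init(is_collective=True)
mesh = fleet.auto.get_mesh()

# A global batch of 8 frames, sharded along dim 0 (4 frames per rank).
coord = dist.shard_tensor(paddle.rand([8, 192, 3]), mesh=mesh, placements=[dist.Shard(0)])
atype = dist.shard_tensor(paddle.zeros([8, 192], dtype="int64"), mesh=mesh, placements=[dist.Shard(0)])

def toy_build(cc, at):
    # Hypothetical stand-in for extend_input_and_build_neighbor_list:
    # it only ever sees the local shard of each input on every rank.
    return cc * 2.0, at + 1

wrapped = dist.local_map(
    func=toy_build,
    in_placements=[coord.placements, atype.placements],   # keep current sharding
    out_placements=[[dist.Shard(0)] for _ in range(2)],   # outputs stay dim-0 sharded
    process_mesh=mesh,
)
ext_coord, ext_atype = wrapped(coord, atype)

wrapped_func_2 in the diff additionally passes reshard_inputs=True, which, as the name suggests, lets the wrapper reshard inputs whose placements do not already match in_placements before the local call.
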
deepmd/pd/train/training.py

Lines changed: 40 additions & 16 deletions
@@ -26,7 +26,9 @@
 from paddle.io import (
     DataLoader,
 )
-
+import paddle.distributed as dist
+from paddle.distributed import fleet
+import functools
 from deepmd.common import (
     symlink_prefix_files,
 )
@@ -101,6 +103,11 @@ def __init__(
         Args:
         - config: The Dict-like configuration with training options.
         """
+        from paddle.distributed import fleet
+        mesh_dims = [("dp", 32)]
+        fleet.auto.create_mesh(mesh_dims)
+        fleet.init(is_collective=True)
+
         enable_prim(True)
         if init_model is not None:
             resume_model = init_model
@@ -748,22 +755,39 @@ def step(_step_id, task_key="Default") -> None:
                 if self.world_size > 1
                 else contextlib.nullcontext
             )
-            with sync_context():
-                with nvprof_context(enable_profiling, "Forward pass"):
-                    model_pred, loss, more_loss = self.wrapper(
-                        **input_dict,
-                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                        label=label_dict,
-                        task_key=task_key,
-                    )
-
-                with nvprof_context(enable_profiling, "Backward pass"):
-                    loss.backward()
+
+            # with sync_context():
+            #     with nvprof_context(enable_profiling, "Forward pass"):
+            #         model_pred, loss, more_loss = self.wrapper(
+            #             **input_dict,
+            #             cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+            #             label=label_dict,
+            #             task_key=task_key,
+            #         )
+
+            #     with nvprof_context(enable_profiling, "Backward pass"):
+            #         loss.backward()
+
+            # if self.world_size > 1:
+            #     # fuse + allreduce manually before optimization if use DDP + no_sync
+            #     # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+            #     hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
+
+            with nvprof_context(enable_profiling, "Forward pass"):
+                for __key in ('coord', 'atype', 'box'):
+                    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                for __key, _ in label_dict.items():
+                    if isinstance(label_dict[__key], paddle.Tensor):
+                        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                model_pred, loss, more_loss = self.wrapper(
+                    **input_dict,
+                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                    label=label_dict,
+                    task_key=task_key,
+                )
 
-            if self.world_size > 1:
-                # fuse + allreduce manually before optimization if use DDP + no_sync
-                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
-                hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
+            with nvprof_context(enable_profiling, "Backward pass"):
+                loss.backward()
 
             if self.gradient_max_norm > 0.0:
                 with nvprof_context(enable_profiling, "Gradient clip"):

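The training-side change mirrors the model-side one from the data end: the manual no_sync + fused_allreduce_gradients path is commented out, and instead each batch tensor from the data loader is converted into a distributed tensor sharded along dim 0, so the auto-parallel engine splits the global batch across the "dp" mesh. Below is a minimal sketch of that sharding step, not part of this commit; the mesh degree, tensor shapes, and the find_energy flag are illustrative assumptions.

# Minimal sketch of the batch-sharding step added to step(); assumes the
# process mesh was created once at startup (fleet.auto.create_mesh + fleet.init),
# as done in the __init__ hunk above.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

mesh = fleet.auto.get_mesh()  # the commit itself fetches the mesh via dist.get_mesh()

input_dict = {
    "coord": paddle.rand([4, 192, 3]),               # illustrative shapes
    "atype": paddle.zeros([4, 192], dtype="int64"),
    "box": paddle.rand([4, 9]),
}
label_dict = {"energy": paddle.rand([4, 1]), "find_energy": 1.0}

# Shard every batch tensor along dim 0 (the frame dimension); non-tensor
# entries such as the find_* flags are left untouched.
for key in ("coord", "atype", "box"):
    input_dict[key] = dist.shard_tensor(input_dict[key], mesh=mesh, placements=[dist.Shard(0)])
for key, val in label_dict.items():
    if isinstance(val, paddle.Tensor):
        label_dict[key] = dist.shard_tensor(val, mesh=mesh, placements=[dist.Shard(0)])
# input_dict / label_dict can now be fed to the wrapped model as in the diff above.
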