
Commit 09dc8dc

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f22feaf commit 09dc8dc

4 files changed: 37 additions & 22 deletions

deepmd/pd/loss/ener.py

Lines changed: 3 additions & 4 deletions

@@ -4,6 +4,7 @@
 )

 import paddle
+import paddle.distributed as dist
 import paddle.nn.functional as F

 from deepmd.pd.loss.loss import (
@@ -21,7 +22,6 @@
 from deepmd.utils.version import (
     check_version_compatibility,
 )
-import paddle.distributed as dist


 def custom_huber_loss(predictions, targets, delta=1.0):
@@ -206,10 +206,9 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
             find_energy = label.get("find_energy", 0.0)
             pref_e = pref_e * find_energy
         if not self.use_l1_all:
-
             logit = energy_pred - energy_label
             logit = dist.reshard(tmp, tmp.process_mesh, [dist.Replicate()])
-
+
             l2_ener_loss = paddle.mean(paddle.square(logit))
             if not self.inference:
                 more_loss["l2_ener_loss"] = self.display_if_exist(
@@ -264,7 +263,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
             force_label = label["force"]
             diff_f = (force_label - force_pred).reshape([-1])
             diff_f = dist.reshard(diff_f, diff_f.process_mesh, [dist.Replicate()])
-
+
             if self.relative_f is not None:
                 force_label_3 = force_label.reshape([-1, 3])
                 norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
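The reshard calls above follow a common auto-parallel pattern: compute on tensors sharded along the batch dimension, then gather them to a replicated layout before reducing to a scalar loss. A minimal sketch of that pattern, assuming Paddle's auto-parallel dist.shard_tensor / dist.reshard API as used in these hunks; the two-GPU mesh and the pred/label tensors are illustrative, not from the repository:

# sketch_reshard_loss.py -- run under:
#   python -m paddle.distributed.launch --gpus="0,1" sketch_reshard_loss.py
import paddle
import paddle.distributed as dist

# 1-D data-parallel process mesh over two ranks (illustrative size).
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# Shard both tensors along the batch dimension.
pred = dist.shard_tensor(paddle.randn([8, 1]), mesh, [dist.Shard(0)])
label = dist.shard_tensor(paddle.randn([8, 1]), mesh, [dist.Shard(0)])

diff = pred - label  # still sharded along dim 0
# Gather to a replicated layout so the mean is taken over the full batch.
diff = dist.reshard(diff, diff.process_mesh, [dist.Replicate()])
loss = paddle.mean(paddle.square(diff))
print(loss)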

deepmd/pd/model/model/make_model.py

Lines changed: 20 additions & 8 deletions

@@ -1,9 +1,14 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import functools
 from typing import (
     Optional,
 )

 import paddle
+import paddle.distributed as dist
+from paddle.distributed import (
+    fleet,
+)

 from deepmd.dpmodel import (
     ModelOutputDef,
@@ -37,9 +42,6 @@
 from deepmd.utils.path import (
     DPPath,
 )
-import paddle.distributed as dist
-from paddle.distributed import fleet
-import functools


 def make_model(T_AtomicModel: type[BaseAtomicModel]):
@@ -182,10 +184,12 @@ def forward_common(
             #     box=bb,
             # )
             wrapped_func_1 = dist.local_map(
-                func=lambda a,b,c: extend_input_and_build_neighbor_list(a,b,self.get_rcut(), self.get_sel(), True, c),
+                func=lambda a, b, c: extend_input_and_build_neighbor_list(
+                    a, b, self.get_rcut(), self.get_sel(), True, c
+                ),
                 in_placements=[ele.placements for ele in [cc, atype, bb]],
                 out_placements=[[dist.Shard(0)] for _ in range(4)],
-                process_mesh=fleet.auto.get_mesh()
+                process_mesh=fleet.auto.get_mesh(),
             )

             (
@@ -209,11 +213,19 @@ def forward_common(
             # )

             wrapped_func_2 = dist.local_map(
-                func=functools.partial(self.forward_common_lower, do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap),
-                in_placements=[ele.placements for ele in [extended_coord, extended_atype, nlist, mapping]],
+                func=functools.partial(
+                    self.forward_common_lower,
+                    do_atomic_virial=do_atomic_virial,
+                    fparam=fp,
+                    aparam=ap,
+                ),
+                in_placements=[
+                    ele.placements
+                    for ele in [extended_coord, extended_atype, nlist, mapping]
+                ],
                 out_placements=[[dist.Shard(0)] for _ in range(6)],
                 process_mesh=fleet.auto.get_mesh(),
-                reshard_inputs=True
+                reshard_inputs=True,
             )
             model_predict_lower = wrapped_func_2(
                 extended_coord,
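Both hunks wrap an ordinary single-rank function with dist.local_map so it runs on each rank's local shards, with in_placements / out_placements describing how the inputs and outputs are laid out on the process mesh. A minimal sketch of the same wrapping, assuming the API surface used above (func, in_placements, out_placements, process_mesh, reshard_inputs); the two-GPU mesh and the local_sum helper are illustrative, not from the repository:

# sketch_local_map.py -- run under:
#   python -m paddle.distributed.launch --gpus="0,1" sketch_local_map.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])


def local_sum(a, b):
    # Operates on each rank's local shard only.
    return a + b


wrapped = dist.local_map(
    func=local_sum,
    in_placements=[[dist.Shard(0)], [dist.Shard(0)]],  # one placements list per input
    out_placements=[[dist.Shard(0)]],                  # one placements list per output
    process_mesh=mesh,
    reshard_inputs=True,  # reshard inputs to the declared placements before calling func
)

x = dist.shard_tensor(paddle.ones([8, 3]), mesh, [dist.Shard(0)])
y = dist.shard_tensor(paddle.ones([8, 3]), mesh, [dist.Shard(0)])
z = wrapped(x, y)  # output is re-wrapped as a dist tensor sharded along dim 0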

deepmd/pd/train/training.py

Lines changed: 13 additions & 9 deletions

@@ -19,16 +19,13 @@
 from paddle.distributed import (
     fleet,
 )
-from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
 from paddle.framework import (
     core,
 )
 from paddle.io import (
     DataLoader,
 )
-import paddle.distributed as dist
-from paddle.distributed import fleet
-import functools
+
 from deepmd.common import (
     symlink_prefix_files,
 )
@@ -103,7 +100,6 @@ def __init__(
         Args:
         - config: The Dict-like configuration with training options.
         """
-        from paddle.distributed import fleet
         mesh_dims = [("dp", 32)]
         fleet.auto.create_mesh(mesh_dims)
         fleet.init(is_collective=True)
@@ -753,7 +749,7 @@ def step(_step_id, task_key="Default") -> None:
                 if self.world_size > 1
                 else contextlib.nullcontext
             )
-
+
             # with sync_context():
             #     with nvprof_context(enable_profiling, "Forward pass"):
             #         model_pred, loss, more_loss = self.wrapper(
@@ -772,11 +768,19 @@ def step(_step_id, task_key="Default") -> None:
            # hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)

            with nvprof_context(enable_profiling, "Forward pass"):
-                for __key in ('coord', 'atype', 'box'):
-                    input_dict[__key] = dist.shard_tensor(input_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                for __key in ("coord", "atype", "box"):
+                    input_dict[__key] = dist.shard_tensor(
+                        input_dict[__key],
+                        mesh=dist.get_mesh(),
+                        placements=[dist.Shard(0)],
+                    )
                 for __key, _ in label_dict.items():
                     if isinstance(label_dict[__key], paddle.Tensor):
-                        label_dict[__key] = dist.shard_tensor(label_dict[__key], mesh=dist.get_mesh(), placements=[dist.Shard(0)])
+                        label_dict[__key] = dist.shard_tensor(
+                            label_dict[__key],
+                            mesh=dist.get_mesh(),
+                            placements=[dist.Shard(0)],
+                        )
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict,
                     cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
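The training-side changes create a 1-D "dp" process mesh with fleet.auto.create_mesh and then shard every batch input along its first (batch) dimension before the forward pass. A minimal sketch of that setup, assuming the mesh registered by fleet.auto.create_mesh is retrievable via fleet.auto.get_mesh() (as in make_model.py above); the 2-way mesh and the dummy batch shapes are illustrative, not from the repository:

# sketch_mesh_and_sharding.py -- run under:
#   python -m paddle.distributed.launch --gpus="0,1" sketch_mesh_and_sharding.py
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

# Global process mesh with a single data-parallel axis (2 ranks here, 32 in the commit).
fleet.auto.create_mesh([("dp", 2)])
fleet.init(is_collective=True)

# Dummy batch standing in for input_dict in training.py.
batch = {
    "coord": paddle.randn([4, 192, 3]),
    "atype": paddle.zeros([4, 192], dtype="int64"),
    "box": paddle.randn([4, 3, 3]),
}
for key in ("coord", "atype", "box"):
    # Shard each input along the batch dimension across the "dp" axis.
    batch[key] = dist.shard_tensor(
        batch[key],
        mesh=fleet.auto.get_mesh(),
        placements=[dist.Shard(0)],
    )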

examples/water/dpa3/run.sh

Lines changed: 1 addition & 1 deletion

@@ -14,4 +14,4 @@ HDFS_USE_FILE_LOCKING=0 python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,
 # python -m paddle.distributed.launch \
 # --gpus=0,1,2,3 \
 # --ips=10.67.200.17,10.67.200.11,10.67.200.13,10.67.200.15 \
-# dp --pd train input_torch.json -l dp_train.log
+# dp --pd train input_torch.json -l dp_train.log
