Skip to content

Commit d236285

Browse files
refine code
1 parent 318dab0 commit d236285

File tree

3 files changed

+50
-36
lines changed

3 files changed

+50
-36
lines changed

.pre-commit-config.yaml

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ repos:
6565
- id: clang-format
6666
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
6767
# markdown, yaml, CSS, javascript
68-
- repo: https://github.com/pre-commit/mirrors-prettier
69-
rev: v4.0.0-alpha.8
70-
hooks:
71-
- id: prettier
72-
types_or: [markdown, yaml, css]
73-
# workflow files cannot be modified by pre-commit.ci
74-
exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
68+
# - repo: https://github.com/pre-commit/mirrors-prettier
69+
# rev: v4.0.0-alpha.8
70+
# hooks:
71+
# - id: prettier
72+
# types_or: [markdown, yaml, css]
73+
# # workflow files cannot be modified by pre-commit.ci
74+
# exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
7575
# Shell
7676
- repo: https://github.com/scop/pre-commit-shfmt
7777
rev: v3.11.0-1
@@ -83,25 +83,25 @@ repos:
8383
hooks:
8484
- id: cmake-format
8585
#- id: cmake-lint
86-
- repo: https://github.com/njzjz/mirrors-bibtex-tidy
87-
rev: v1.13.0
88-
hooks:
89-
- id: bibtex-tidy
90-
args:
91-
- --curly
92-
- --numeric
93-
- --align=13
94-
- --blank-lines
95-
# disable sort: the order of keys and fields has explict meanings
96-
#- --sort=key
97-
- --duplicates=key,doi,citation,abstract
98-
- --merge=combine
99-
#- --sort-fields
100-
#- --strip-comments
101-
- --trailing-commas
102-
- --encode-urls
103-
- --remove-empty-fields
104-
- --wrap=80
86+
# - repo: https://github.com/njzjz/mirrors-bibtex-tidy
87+
# rev: v1.13.0
88+
# hooks:
89+
# - id: bibtex-tidy
90+
# args:
91+
# - --curly
92+
# - --numeric
93+
# - --align=13
94+
# - --blank-lines
95+
# # disable sort: the order of keys and fields has explict meanings
96+
# #- --sort=key
97+
# - --duplicates=key,doi,citation,abstract
98+
# - --merge=combine
99+
# #- --sort-fields
100+
# #- --strip-comments
101+
# - --trailing-commas
102+
# - --encode-urls
103+
# - --remove-empty-fields
104+
# - --wrap=80
105105
# license header
106106
- repo: https://github.com/Lucas-C/pre-commit-hooks
107107
rev: v1.5.5

deepmd/pd/train/training.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import contextlib
23
import functools
34
import logging
45
import time
@@ -18,6 +19,7 @@
1819
from paddle.distributed import (
1920
fleet,
2021
)
22+
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
2123
from paddle.framework import (
2224
core,
2325
)
@@ -741,16 +743,27 @@ def step(_step_id, task_key="Default") -> None:
741743
pref_lr = _lr.start_lr
742744
else:
743745
pref_lr = cur_lr
744-
with nvprof_context(enable_profiling, "Forward pass"):
745-
model_pred, loss, more_loss = self.wrapper(
746-
**input_dict,
747-
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
748-
label=label_dict,
749-
task_key=task_key,
750-
)
746+
sync_context = (
747+
self.wrapper.no_sync
748+
if self.world_size > 1
749+
else contextlib.nullcontext
750+
)
751+
with sync_context():
752+
with nvprof_context(enable_profiling, "Forward pass"):
753+
model_pred, loss, more_loss = self.wrapper(
754+
**input_dict,
755+
cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
756+
label=label_dict,
757+
task_key=task_key,
758+
)
759+
760+
with nvprof_context(enable_profiling, "Backward pass"):
761+
loss.backward()
751762

752-
with nvprof_context(enable_profiling, "Backward pass"):
753-
loss.backward()
763+
if self.world_size > 1:
764+
# fuse + allreduce manually before optimization if use DDP + no_sync
765+
# details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
766+
hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
754767

755768
if self.gradient_max_norm > 0.0:
756769
with nvprof_context(enable_profiling, "Gradient clip"):

deepmd/pd/utils/env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
ncpus = os.cpu_count()
2828
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(0, ncpus)))
2929
# Make sure DDP uses correct device if applicable
30-
LOCAL_RANK = paddle.distributed.get_rank()
30+
LOCAL_RANK = os.environ.get("PADDLE_LOCAL_RANK")
31+
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
3132

3233
if os.environ.get("DEVICE") == "cpu" or paddle.device.cuda.device_count() <= 0:
3334
DEVICE = "cpu"

0 commit comments

Comments
 (0)