Commit 0d490f7
phlrain authored and maxiaolong001 committed
optimize dual pp overlap (PaddlePaddle#74527)
* optimize dual pp overlap
* polish code
* polish code
1 parent 94c32b4 commit 0d490f7

4 files changed: +40 −2 lines

paddle/fluid/distributed/collective/deep_ep/include/event_pool.h

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ namespace deep_ep::detail {
 
 class EventPool {
  public:
-  EventPool() = default;
+  EventPool();
   EventPool(const EventPool&) = delete;
   EventPool(EventPool&&) = delete;
   ~EventPool();

paddle/fluid/distributed/collective/deep_ep/src/event_pool.cc

Lines changed: 10 additions & 0 deletions

@@ -22,6 +22,16 @@ EventPool &EventPool::Instance() {
   return pool;
 }
 
+EventPool::EventPool() {
+  for (size_t i = 0; i < 1000; ++i) {
+    cudaEvent_t new_event;
+    CUDA_CHECK(cudaEventCreate(&new_event));
+
+    cudaEventRecord(new_event, 0);
+    incomplished_events_.push(new_event);
+  }
+}
+
 EventPool::~EventPool() {
   const auto &DestroyEvent = [](cudaEvent_t event) {
     cudaError_t e = cudaEventDestroy(event);
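
The new constructor pre-warms the pool: it creates 1000 CUDA events up front and records each one on the default stream, so every pooled event starts out completed and `cudaEventCreate` never runs on the hot path. A minimal Python sketch of the same pre-warming idea, assuming `paddle.device.cuda.Event` with a `record()` method; the real pool is the C++ class above, and `WarmEventPool`, `acquire`, and `release` are illustrative names, not Paddle API:

    import queue

    import paddle


    class WarmEventPool:
        # Mirrors the 1000 events pre-created by EventPool::EventPool().
        POOL_SIZE = 1000

        def __init__(self):
            self._events = queue.Queue()
            for _ in range(self.POOL_SIZE):
                event = paddle.device.cuda.Event()
                # Record once so the event is already "completed" and can
                # be handed out without extra synchronization.
                event.record()
                self._events.put(event)

        def acquire(self):
            # Reuse a pooled event instead of paying creation cost per use.
            return self._events.get()

        def release(self, event):
            self._events.put(event)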

python/paddle/distributed/fleet/meta_parallel/dualpipev.py

Lines changed: 20 additions & 1 deletion

@@ -37,7 +37,7 @@
     PipelineParallel,
 )
 from .pp_utils.batch_comm_helper import BatchCommHelper
-from .zero_bubble_utils import WeightGradStore
+from .zero_bubble_utils import EventStore, WeightGradStore
 
 __all__ = []
 

@@ -358,6 +358,10 @@ def _commit_and_wait_comm(
             else 0
         )
         if common_forward_ops_num == 0 and common_backward_ops_num == 0:
+            if EventStore.event is not None:
+                e_t = EventStore.event
+                EventStore.event = None
+                return e_t
             return deep_ep.get_event_from_custom_stream(
                 paddle.device.current_stream().stream_base
             )

@@ -387,13 +391,28 @@ def _commit_and_wait_comm(
             pp_raw_stream
         )
 
+        backward_outer_event_wait = False
+        if EventStore.event is not None:
+            with paddle.device.stream_guard(
+                paddle.device.Stream(stream_base=pp_raw_stream)
+            ):
+                EventStore.event.current_stream_wait()
+
+            EventStore.set(None)
+            self.pp_group.process_group.set_outer_wait(True)
+
+            backward_outer_event_wait = True
+
         if common_backward_ops_num > 0:
             bwd_reqs = batch_isend_irecv(self.comm_backward_ops)
 
         if not use_stream_wait_event:
            for req in bwd_reqs:
                req.wait()
 
+        if backward_outer_event_wait:
+            self.pp_group.process_group.set_outer_wait(False)
+
         if use_stream_wait_event:
             forward_event_to_wait.current_stream_wait()
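Two overlap optimizations land in `_commit_and_wait_comm`. First, when there are no pending forward or backward comm ops, the method now returns the event stashed in `EventStore` (consuming it) instead of always minting a fresh event from the current stream. Second, before launching the batched backward send/recv, the pipeline-parallel stream waits on any stashed event, and `set_outer_wait(True)`/`set_outer_wait(False)` bracket the `batch_isend_irecv` so the process group knows an external wait is already in place. A hedged sketch of the one-shot hand-off pattern; `EventStore` and `current_stream_wait()` come from this commit and deep_ep, while the caller functions are hypothetical:

    from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore


    def producer_phase(comm_event):
        # An earlier scheduler phase captured an event for its communication
        # and stashes it for whoever commits the next batch of comm ops.
        EventStore.set(comm_event)


    def consumer_phase():
        # The consumer takes the event exactly once: either it is returned
        # directly (nothing to commit) or the comm stream waits on it before
        # batch_isend_irecv, as _commit_and_wait_comm does above.
        event = EventStore.event
        if event is not None:
            EventStore.set(None)  # clear the slot: one-shot hand-off
            event.current_stream_wait()  # deep_ep event API, per the diff
        return event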

python/paddle/distributed/fleet/meta_parallel/zero_bubble_utils.py

Lines changed: 9 additions & 0 deletions

@@ -54,6 +54,15 @@ def clear(cls) -> None:
         cls.funcs_queue = queue.Queue()
 
 
+class EventStore:
+
+    event = None
+
+    @classmethod
+    def set(cls, event) -> None:
+        cls.event = event
+
+
 def fold_init_dims(tensor):
     # NOTE(zhangyuqin1998): Reshape a rank-3 tensor from P x M x N to (P * M) x N,
     # to keep weight_grad in a correct rank. See phi::FoldInitDims.
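
`EventStore` follows the same class-level-state pattern as `WeightGradStore` in this file: a single process-wide slot, never instantiated. A minimal usage sketch; the producing call is taken from the dualpipev.py diff above and assumes `deep_ep` and `paddle` are imported, while the surrounding code is illustrative:

    # Stash an event produced on the current stream.
    event = deep_ep.get_event_from_custom_stream(
        paddle.device.current_stream().stream_base
    )
    EventStore.set(event)

    # ...later, a different phase consumes and clears the slot:
    pending = EventStore.event
    EventStore.set(None)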
