Skip to content

Commit 698721e

Browse files
authored
[Infercne] make internode_ll_two_stage supports async and hook mode. (#74405)
1 parent 54bf220 commit 698721e

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,7 +1983,10 @@ Buffer::low_latency_dispatch_two_stage(
19831983
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
19841984

19851985
// Wait previous tasks to be finished
1986-
auto launch_stream = calc_ctx->stream();
1986+
auto compute_stream = calc_ctx->stream();
1987+
auto launch_stream = async ? comm_stream : compute_stream;
1988+
EP_HOST_ASSERT(!(async && return_recv_hook));
1989+
19871990
auto return_x_dtype = phi::DataType::BFLOAT16;
19881991
if (use_fp8) {
19891992
return_x_dtype = phi::DataType::FLOAT8_E4M3FN;
@@ -2084,11 +2087,16 @@ Buffer::low_latency_dispatch_two_stage(
20842087
phases,
20852088
low_latency_buffer_idx);
20862089
};
2087-
// TODO(Zhenyu Li): supports async/return_recv_hook
2088-
launcher((LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2089-
// Wait streams
2090+
launcher(return_recv_hook
2091+
? LOW_LATENCY_SEND_PHASE
2092+
: (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2093+
// Async event
20902094
std::optional<EventHandle> event;
2095+
if (async) {
2096+
event = EventHandle(launch_stream);
2097+
}
20912098
std::optional<std::function<void()>> recv_hook = std::nullopt;
2099+
if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
20922100
return {packed_recv_x,
20932101
packed_recv_x_scales,
20942102
packed_rdma_recv_x,
@@ -2158,7 +2166,9 @@ Buffer::low_latency_combine_two_stage(
21582166
auto buffer = layout.buffers[low_latency_buffer_idx];
21592167
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
21602168

2161-
auto launch_stream = calc_ctx->stream();
2169+
auto compute_stream = calc_ctx->stream();
2170+
auto launch_stream = async ? comm_stream : compute_stream;
2171+
EP_HOST_ASSERT(!(async && return_recv_hook));
21622172

21632173
// Allocate output tensor
21642174
deep_ep::detail::Tensor combined_x;
@@ -2204,12 +2214,17 @@ Buffer::low_latency_combine_two_stage(
22042214
dispatch_use_fp8,
22052215
low_latency_buffer_idx);
22062216
};
2207-
// TODO(Zhenyu Li): supports async/return_recv_hook
2208-
launcher((LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2209-
// Wait streams
2217+
launcher(return_recv_hook
2218+
? LOW_LATENCY_SEND_PHASE
2219+
: (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2220+
// Async event
22102221
std::optional<EventHandle> event;
2222+
if (async) {
2223+
event = EventHandle(launch_stream);
2224+
}
22112225
// Receiver callback
22122226
std::optional<std::function<void()>> recv_hook = std::nullopt;
2227+
if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
22132228
// Return values
22142229
return {combined_x, event, recv_hook};
22152230
}

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll_two_stage.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ __global__ __launch_bounds__(
148148
num_bytes_per_msg_rdma_revecier_and_nvl_sender % sizeof(int4) == 0);
149149
EP_DEVICE_ASSERT(num_bytes_per_msg_rdma_to_nvl % sizeof(int4) == 0);
150150

151+
// Sending phase
152+
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV;
153+
151154
/* RDMA Sender */
152155
{
153156
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
@@ -363,6 +366,10 @@ __global__ __launch_bounds__(
363366
}
364367
}
365368

369+
// Receiving phase
370+
LOW_LATENCY_DISPATCH_RECV:
371+
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return;
372+
366373
/* RDMA Receiver and NVL Sender */
367374
{
368375
const int sms_per_rdma = num_sms / kNumRdmaRanks;
@@ -828,6 +835,9 @@ __global__ __launch_bounds__(
828835
const size_t NVL_BUFFER_OFFSET =
829836
nvl_buffer_id * NVL_BUFFER_X_BYTES_PER_BUFFER;
830837

838+
// Sending phase
839+
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV;
840+
831841
// Clean up next buffer
832842
if (sm_id == 0) {
833843
#pragma unroll
@@ -1068,6 +1078,10 @@ __global__ __launch_bounds__(
10681078
}
10691079
}
10701080

1081+
// Receiving phase
1082+
LOW_LATENCY_COMBINE_RECV:
1083+
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return;
1084+
10711085
/* RDMA Receiver / RDMA Reducer */
10721086
// Wait all rdma ranks to arrive
10731087
if (sm_id < kNumRdmaRanks) {

0 commit comments

Comments
 (0)