@@ -1983,7 +1983,10 @@ Buffer::low_latency_dispatch_two_stage(
1983
1983
auto next_buffer = layout.buffers [low_latency_buffer_idx ^= 1 ];
1984
1984
1985
1985
// Wait previous tasks to be finished
1986
- auto launch_stream = calc_ctx->stream ();
1986
+ auto compute_stream = calc_ctx->stream ();
1987
+ auto launch_stream = async ? comm_stream : compute_stream;
1988
+ EP_HOST_ASSERT (!(async && return_recv_hook));
1989
+
1987
1990
auto return_x_dtype = phi::DataType::BFLOAT16;
1988
1991
if (use_fp8) {
1989
1992
return_x_dtype = phi::DataType::FLOAT8_E4M3FN;
@@ -2084,11 +2087,16 @@ Buffer::low_latency_dispatch_two_stage(
2084
2087
phases,
2085
2088
low_latency_buffer_idx);
2086
2089
};
2087
- // TODO(Zhenyu Li): supports async/return_recv_hook
2088
- launcher ((LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2089
- // Wait streams
2090
+ launcher (return_recv_hook
2091
+ ? LOW_LATENCY_SEND_PHASE
2092
+ : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2093
+ // Async event
2090
2094
std::optional<EventHandle> event;
2095
+ if (async) {
2096
+ event = EventHandle (launch_stream);
2097
+ }
2091
2098
std::optional<std::function<void ()>> recv_hook = std::nullopt;
2099
+ if (return_recv_hook) recv_hook = [=]() { launcher (LOW_LATENCY_RECV_PHASE); };
2092
2100
return {packed_recv_x,
2093
2101
packed_recv_x_scales,
2094
2102
packed_rdma_recv_x,
@@ -2158,7 +2166,9 @@ Buffer::low_latency_combine_two_stage(
2158
2166
auto buffer = layout.buffers [low_latency_buffer_idx];
2159
2167
auto next_buffer = layout.buffers [low_latency_buffer_idx ^= 1 ];
2160
2168
2161
- auto launch_stream = calc_ctx->stream ();
2169
+ auto compute_stream = calc_ctx->stream ();
2170
+ auto launch_stream = async ? comm_stream : compute_stream;
2171
+ EP_HOST_ASSERT (!(async && return_recv_hook));
2162
2172
2163
2173
// Allocate output tensor
2164
2174
deep_ep::detail::Tensor combined_x;
@@ -2204,12 +2214,17 @@ Buffer::low_latency_combine_two_stage(
2204
2214
dispatch_use_fp8,
2205
2215
low_latency_buffer_idx);
2206
2216
};
2207
- // TODO(Zhenyu Li): supports async/return_recv_hook
2208
- launcher ((LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2209
- // Wait streams
2217
+ launcher (return_recv_hook
2218
+ ? LOW_LATENCY_SEND_PHASE
2219
+ : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
2220
+ // Async event
2210
2221
std::optional<EventHandle> event;
2222
+ if (async) {
2223
+ event = EventHandle (launch_stream);
2224
+ }
2211
2225
// Receiver callback
2212
2226
std::optional<std::function<void ()>> recv_hook = std::nullopt;
2227
+ if (return_recv_hook) recv_hook = [=]() { launcher (LOW_LATENCY_RECV_PHASE); };
2213
2228
// Return values
2214
2229
return {combined_x, event, recv_hook};
2215
2230
}
0 commit comments