Skip to content

Commit 4702127

Browse files
authored
support dynamic quant for w4afp8 in deepep (PaddlePaddle#76262)
1 parent a88f07c commit 4702127

File tree

7 files changed

+205
-82
lines changed

7 files changed

+205
-82
lines changed

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,7 +1752,8 @@ Buffer::low_latency_dispatch(
17521752
int num_experts,
17531753
bool use_fp8,
17541754
bool async,
1755-
bool return_recv_hook) {
1755+
bool return_recv_hook,
1756+
int num_per_channel) {
17561757
EP_HOST_ASSERT(low_latency_mode);
17571758

17581759
// Tensor checks
@@ -1768,7 +1769,8 @@ Buffer::low_latency_dispatch(
17681769

17691770
auto num_tokens = static_cast<int>(x.size(0)),
17701771
hidden = static_cast<int>(x.size(1));
1771-
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
1772+
auto num_scales = num_per_channel == -1 ? 1 : hidden / 128,
1773+
num_topk = static_cast<int>(topk_idx.size(1));
17721774
int num_local_experts = num_experts / num_ranks;
17731775

17741776
// Buffer control
@@ -1872,7 +1874,8 @@ Buffer::low_latency_dispatch(
18721874
use_fp8,
18731875
workspace,
18741876
launch_stream,
1875-
phases);
1877+
phases,
1878+
num_per_channel);
18761879
};
18771880
launcher(return_recv_hook
18781881
? LOW_LATENCY_SEND_PHASE
@@ -2976,7 +2979,8 @@ Buffer::low_latency_dispatch_api(
29762979
int num_experts,
29772980
bool use_fp8,
29782981
bool async,
2979-
bool return_recv_hook) {
2982+
bool return_recv_hook,
2983+
int num_per_channel) {
29802984
#ifdef PADDLE_WITH_NVSHMEM
29812985
const auto& x_ = ConvertPaddleTensorToDetailTensor(x);
29822986
const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx);
@@ -2994,7 +2998,8 @@ Buffer::low_latency_dispatch_api(
29942998
num_experts,
29952999
use_fp8,
29963000
async,
2997-
return_recv_hook);
3001+
return_recv_hook,
3002+
num_per_channel);
29983003

29993004
auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
30003005

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ struct Buffer {
279279
int num_experts,
280280
bool use_fp8,
281281
bool async,
282-
bool return_recv_hook);
282+
bool return_recv_hook,
283+
int num_per_channel);
283284

284285
std::tuple<deep_ep::detail::Tensor,
285286
std::optional<EventHandle>,
@@ -452,7 +453,8 @@ struct Buffer {
452453
int num_experts,
453454
bool use_fp8,
454455
bool async,
455-
bool return_recv_hook);
456+
bool return_recv_hook,
457+
int num_per_channel);
456458

457459
std::tuple<paddle::Tensor,
458460
std::optional<EventHandle>,

paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ void dispatch(void* packed_recv_x,
318318
bool use_fp8,
319319
void* workspace,
320320
cudaStream_t stream,
321-
int phases);
321+
int phases,
322+
int num_per_channel);
322323

323324
void combine(void* combined_x,
324325
void* rdma_recv_x,

0 commit comments

Comments
 (0)