Skip to content

Commit 918ccdb

Browse files
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * suppport pd v1 + chunk prefill * suppport pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <[email protected]>
1 parent 9845f0d commit 918ccdb

22 files changed

+1839
-344
lines changed

custom_ops/gpu_ops/update_inputs_v1.cu

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
3232
const int max_bsz,
3333
const int input_ids_stride,
3434
const int block_num_per_seq,
35-
const int block_size) {
35+
const int block_size,
36+
bool prefill_one_step_stop) {
3637
int thread_idx = threadIdx.x;
3738
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
3839
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
5455
seq_lens_encoder[thread_idx] = 0;
5556
} else {
5657
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
57-
// decoding
58-
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
59-
seq_lens_this_time[thread_idx] = 1;
60-
seq_lens_encoder[thread_idx] = 0;
61-
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
62-
input_ids_now[0] = next_tokens[thread_idx];
58+
if (prefill_one_step_stop) {
59+
// prefill done, stop
60+
stop_flags[thread_idx] = true;
61+
seq_lens_this_time[thread_idx] = 0;
62+
seq_lens_decoder[thread_idx] = 0;
63+
seq_lens_encoder[thread_idx] = 0;
64+
stop_flag_now_int = 1;
65+
} else{
66+
// decoding
67+
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
68+
seq_lens_this_time[thread_idx] = 1;
69+
seq_lens_encoder[thread_idx] = 0;
70+
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
71+
input_ids_now[0] = next_tokens[thread_idx];
6372

64-
// to judge whether block is not enough
65-
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
66-
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
67-
// should be scheduled by server
68-
is_block_step[thread_idx] = true;
69-
seq_lens_this_time[thread_idx]= 0;
70-
stop_flags[thread_idx] = true;
71-
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
72-
seq_lens_decoder[thread_idx] = 0;
73-
stop_flag_now_int = 1;
73+
// to judge whether block is not enough
74+
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
75+
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
76+
// should be scheduled by server
77+
is_block_step[thread_idx] = true;
78+
seq_lens_this_time[thread_idx]= 0;
79+
stop_flags[thread_idx] = true;
80+
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
81+
seq_lens_decoder[thread_idx] = 0;
82+
stop_flag_now_int = 1;
83+
}
7484
}
7585
} else
7686
{
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
110120
#else
111121
auto cu_stream = input_ids.stream();
112122
#endif
123+
bool prefill_one_step_stop = false;
124+
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
125+
if (env_p[0] == '1') {
126+
prefill_one_step_stop = true;
127+
}
128+
}
113129
const int max_bsz = stop_flags.shape()[0];
114130
const int now_bsz = seq_lens_this_time.shape()[0];
115131
const int input_ids_stride = input_ids.shape()[1];
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
133149
max_bsz,
134150
input_ids_stride,
135151
block_num_per_seq,
136-
block_size);
152+
block_size,
153+
prefill_one_step_stop);
137154
auto not_need_stop_cpu =
138155
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
139156
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());

0 commit comments

Comments
 (0)