draft for enable prefill in cudagraph #3354
base: develop
@@ -131,6 +131,7 @@ def __init__(
        self.use_cudagraph = self.graph_opt_config.use_cudagraph
        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
        self.cudagraph_capture_prefill = self.graph_opt_config.cudagraph_capture_prefill

        # Initialize share inputs
        self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -166,6 +167,15 @@ def exist_prefill(self):
        else:
            return 0

    def exist_decode(self):
        """
        Check whether a decode stage exists.
        """
        if int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0:
            return 1
        else:
            return 0

    def _init_speculative_proposer(self):
        """
        Init speculative proposer
@@ -541,6 +551,54 @@ def get_attr_from_request(request, attr, default_value=None):
        if self.speculative_method in ["mtp"]:
            self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

    def _dummy_prefill_inputs_prefill(self, num_tokens: int, seq_length: int, expected_decode_len: int):
        """Set dummy prefill inputs to share_inputs"""
        # NOTE(gongshaotian): The maximum decoding length equals the expected number of decoded tokens plus the eos token.

        batch_size = 1
        max_dec_len = expected_decode_len + 1
        full_length = min(
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )

        # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer is not large enough, which makes the result contain NaNs.
        # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
        if self.fd_config.parallel_config.enable_expert_parallel:
            full_length = min(full_length, 32)

        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
            input_length + self.cache_config.block_size - 1
        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
            self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.share_inputs["eos_token_id"][:] = np.array(
                [2] * self.model_config.eos_tokens_lens, dtype="int64"
            ).reshape(-1, 1)
            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.share_inputs["prompt_lens"][idx : idx + 1] = 0
            self.share_inputs["step_idx"][idx : idx + 1] = 0
            self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
            self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len
            self.share_inputs["stop_flags"][idx : idx + 1] = False
            self.share_inputs["temperature"][idx : idx + 1] = 1

            self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
            self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length

            self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num
            self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                idx * block_num, (idx + 1) * block_num, 1
            )
        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
        """Set dummy prefill inputs to share_inputs"""
        # NOTE(gongshaotian): The maximum decoding length equals the expected number of decoded tokens plus the eos token.
@@ -909,6 +967,28 @@ def initialize_forward_meta(self):
            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
        )

Reviewer comment: the pointer of the seq_lens_encoder tensor changes; also print the other tensors produced by the get_block_shape kernel and check them (a pointer-logging sketch follows this hunk).

        if self.cudagraph_capture_prefill:
            only_prefill_batch = True
            decode_exists = None
            if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
                # Collect the status of all workers
                only_prefill_batch_list = []
                decode_exists = self.exist_decode()
                paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
                only_prefill_batch = all(only_prefill_batch_list)
            self.fd_config.parallel_config.moe_phase.phase = "prefill" if only_prefill_batch else "decode"

Reviewer comment: why does moe_phase.phase need to be changed?

            self.forward_meta.step_use_cudagraph = (
                self.use_cudagraph
                and self.cudagraph_capture_prefill
                and only_prefill_batch
                and not (decode_exists if decode_exists is not None else self.exist_decode())
            )

        print(
            f"in initialize_forward_meta , self.forward_meta.step_use_cudagraph:{self.forward_meta.step_use_cudagraph}"
        )

        # Initialize attention meta data
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)
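Following up on the reviewer comment above, a minimal debugging sketch along these lines could log the device pointers of the tensors produced by the get_block_shape kernel across steps. It assumes paddle.Tensor.data_ptr() is available (recent Paddle releases), and the tensor names in the commented call site are illustrative only.

# Hypothetical helper for the reviewer's suggestion: log device pointers of the
# get_block_shape outputs so pointer changes during CUDA graph capture show up.
# Assumes paddle.Tensor.data_ptr() exists (recent Paddle releases); otherwise
# print the tensors themselves.
def log_tensor_ptrs(tag: str, tensors: dict) -> None:
    for name, tensor in tensors.items():
        print(f"[{tag}] {name}: ptr={tensor.data_ptr()}, shape={tensor.shape}")

# Illustrative call site inside initialize_forward_meta (tensor names assumed):
# log_tensor_ptrs(
#     "forward_meta",
#     {
#         "seq_lens_encoder": self.share_inputs["seq_lens_encoder"],
#         "seq_lens_this_time": self.share_inputs["seq_lens_this_time"],
#         "seq_lens_decoder": self.share_inputs["seq_lens_decoder"],
#     },
# )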
@@ -1007,6 +1087,165 @@ def initialize_attn_backend(self) -> None:

        self.attn_backends.append(attn_backend)

Reviewer comment: does the whole dummy run need to be rewritten? Wouldn't rewriting _dummy_prefill_inputs_prefill alone be enough?

    def _dummy_run_prefill(
        self,
        num_tokens: paddle.Tensor,
        batch_size: paddle.Tensor,
        expected_decode_len: int = 1,
        in_capturing: bool = False,
    ) -> paddle.Tensor:
        """
        Use dummy inputs to run before formal execution.
        Args:
            num_tokens: Number of tokens in the dummy prefill batch
            expected_decode_len: Expected number of generated tokens
            in_capturing: Whether the CUDA graph is in the capturing state
        """
        print("####### Get in _dummy_run_prefill ######")
        self._dummy_prefill_inputs_prefill(
            num_tokens=num_tokens,
            seq_length=num_tokens,
            expected_decode_len=expected_decode_len,
        )
        if self.speculative_method in ["mtp"]:
            self.proposer._dummy_prefill_inputs_prefill(
                num_tokens=num_tokens,
                seq_length=num_tokens,
                expected_decode_len=expected_decode_len,
            )
        while True:
            # 1. Initialize forward meta and attention meta data
            self._prepare_inputs()

            # 2. Padding inputs for cuda graph
            self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
            print(
                f"in _dummy_run_prefill ,self.forward_meta.step_use_cudagraph:{self.forward_meta.step_use_cudagraph}"
            )
            ids_remove_padding = self.share_inputs["ids_remove_padding"]
            print(f"in _dummy_run_prefill ,ids_remove_padding:{ids_remove_padding}")
            self.padding_cudagraph_inputs()

            # 3. Run model
            if self.enable_mm:
                model_output = self.model(
                    self.share_inputs["ids_remove_padding"],
                    self.share_inputs["image_features"],
                    self.forward_meta,
                )
                hidden_states = model_output
            else:
                model_output = self.model(
                    ids_remove_padding=self.share_inputs["ids_remove_padding"],
                    forward_meta=self.forward_meta,
                )

                hidden_states = rebuild_padding(
                    model_output,
                    self.share_inputs["cum_offsets"],
                    self.share_inputs["seq_lens_this_time"],
                    self.share_inputs["seq_lens_decoder"],
                    self.share_inputs["seq_lens_encoder"],
                    (
                        self.share_inputs["output_padding_offset"] if self.speculative_decoding else None
                    ),  # required by speculative decoding
                    self.parallel_config.max_model_len,
                )

            # 4. Execute spec decode
            logits = self.model.compute_logits(hidden_states)

            if not self.speculative_decoding:
                set_value_by_flags_and_idx(
                    self.share_inputs["pre_ids"],
                    self.share_inputs["input_ids"],
                    self.share_inputs["seq_lens_this_time"],
                    self.share_inputs["seq_lens_encoder"],
                    self.share_inputs["seq_lens_decoder"],
                    self.share_inputs["step_idx"],
                    self.share_inputs["stop_flags"],
                )
                sampler_output = self.sampler(logits, self.sampling_metadata)
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
            else:
                self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.parallel_config.max_model_len,
                    self.share_inputs,
                )
                sampler_output = None
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
                    paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
                    paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
                    paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
            # 5. Post process
            model_output_data = ModelOutputData(
                next_tokens=self.share_inputs["next_tokens"],
                stop_flags=self.share_inputs["stop_flags"],
                step_idx=self.share_inputs["step_idx"],
                max_dec_len=self.share_inputs["max_dec_len"],
                pre_ids=self.share_inputs["pre_ids"],
                seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
                eos_token_id=self.share_inputs["eos_token_id"],
                not_need_stop=self.share_inputs["not_need_stop"],
                input_ids=self.share_inputs["input_ids"],
                stop_nums=self.share_inputs["stop_nums"],
                seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
                seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
                is_block_step=self.share_inputs["is_block_step"],
                full_hidden_states=model_output,
                msg_queue_id=self.parallel_config.msg_queue_id,
                mp_rank=self.local_rank,
                use_ep=self.parallel_config.use_ep,
                draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
                actual_draft_token_num=(
                    self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
                ),
                accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
                accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
                enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
                think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
                need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
                reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
                stop_token_ids=self.share_inputs["stop_seqs"],
                stop_seqs_len=self.share_inputs["stop_seqs_len"],
            )

            post_process(
                sampler_output=sampler_output,
                model_output=model_output_data,
                share_inputs=self.share_inputs,
                block_size=self.cache_config.block_size,
                speculative_decoding=self.speculative_decoding,
                skip_save_output=True,
            )

            if self.speculative_decoding:
                if self.speculative_method == "mtp":
                    self.proposer.run(full_hidden_states=model_output)
                else:
                    self.proposer.run(share_inputs=self.share_inputs)

            # 7. Update 'infer_seed' and step_cuda()
            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
            step_cuda(
                self.share_inputs,
                self.cache_config.block_size,
                self.cache_config.enc_dec_block_num,
                self.speculative_config,
                self.cache_config.enable_prefix_caching,
            )

            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
                break
            # break

    def _dummy_run(
        self,
        num_tokens: paddle.Tensor,
@@ -1243,6 +1482,36 @@ def capture_model(self) -> None:
        time_after_capture = time.perf_counter()
        logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

        if self.cudagraph_capture_prefill:
            self.capture_model_prefill()

Reviewer comment on lines +1485 to +1487: add this in gpu_worker instead; it sits at the same level as the original gpu_model_runner capture call.

    def capture_model_prefill(self) -> None:
        """
        Trigger CUDA Graph capture for all shapes in the CUDA graph capture list.
        """
        if not self.use_cudagraph:
            logger.info("Skipping CUDA graph capture. Please check GraphOptimizationConfig")
            return
        if not self.cudagraph_capture_prefill:
            logger.info("Cuda graph prefill capture is disabled.")
            return
        time_before_capture = time.perf_counter()
        expected_decode_len = 1
        # capture_sizes = self.cudagraph_capture_sizes.copy()
        capture_token_nums = [256]
        fixed_batch_size = 1
        for token_nums in sorted(capture_token_nums, reverse=True):
            self._dummy_run_prefill(
                num_tokens=token_nums,
                batch_size=1,
                in_capturing=True,
                expected_decode_len=expected_decode_len,
            )
            logger.info(f"Warm up the model with the token_nums:{token_nums}, num tokens:{expected_decode_len}")

        time_after_capture = time.perf_counter()
        logger.info("now get in capture_model_prefill")

    @sot_warmup_guard(True)
    def sot_warmup(self) -> None:
        start_time = time.perf_counter()
@@ -1316,10 +1585,13 @@ class at the server level, which is too granular for ModelRunner.
            )
            hidden_states = model_output
        else:
            # print("seq_lens_this_time passed to the model", self.forward_meta.seq_lens_this_time.shape)
            # print("input_ids", self.forward_meta.input_ids.shape)
            model_output = self.model(
                ids_remove_padding=self.share_inputs["ids_remove_padding"],
                forward_meta=self.forward_meta,
            )
            # print("self.share_inputs['seq_lens_this_time']", self.share_inputs["seq_lens_this_time"])
            hidden_states = rebuild_padding(
                model_output,
                self.share_inputs["cum_offsets"],
Reviewer comment: can this logic guarantee that input_length equals the num_tokens we want to capture?
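For reference, a minimal standalone sketch of the length computation in _dummy_prefill_inputs_prefill illustrates the reviewer's concern. The default values below (max_model_len=8192, kv_cache_ratio=0.75) are assumptions chosen for illustration, not taken from the PR.

# Standalone sketch of the length math in _dummy_prefill_inputs_prefill.
# The defaults (max_model_len=8192, kv_cache_ratio=0.75) are illustrative
# assumptions; substitute the values from your deployment config.
def dummy_prefill_input_length(
    num_tokens: int,
    batch_size: int = 1,
    expected_decode_len: int = 1,
    max_model_len: int = 8192,
    kv_cache_ratio: float = 0.75,
    enable_expert_parallel: bool = False,
) -> int:
    max_dec_len = expected_decode_len + 1
    full_length = min(num_tokens // batch_size, max_model_len - max_dec_len)
    if enable_expert_parallel:
        # The DeepEP buffer workaround in the PR caps the dummy prompt even further.
        full_length = min(full_length, 32)
    return int(full_length * kv_cache_ratio)


print(dummy_prefill_input_length(256))                               # 192, not the 256 being captured
print(dummy_prefill_input_length(256, enable_expert_parallel=True))  # 24

Unless kv_cache_ratio is 1.0 and expert parallelism is disabled, the graph is captured for fewer tokens than requested, which is exactly what the reviewer is asking about.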