draft for enable prefill in cudagraph #3354
base: develop
Conversation
Thanks for your contribution!

    if self.cudagraph_capture_prefill:
        self.capture_model_prefill()
Put this in gpu_worker instead; it sits at the same level as the original gpu_model_runner.
@@ -1007,6 +1087,165 @@ def initialize_attn_backend(self) -> None:

        self.attn_backends.append(attn_backend)

    def _dummy_run_prefill(
Does the entire dummy run need to be rewritten? Wouldn't rewriting just _dummy_prefill_inputs_prefill be enough?
    decode_exists = self.exist_decode()
    paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
    only_prefill_batch = all(only_prefill_batch_list)
    self.fd_config.parallel_config.moe_phase.phase = "prefill" if only_prefill_batch else "decode"
Why does moe_phase.phase need to be changed here?
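A minimal sketch of the decision being questioned above (plain Python; the paddle.distributed.all_gather_object call is emulated by passing the gathered per-rank flags in directly, and the rank values are invented for illustration):

```python
# Hypothetical helper mirroring the snippet under review: the global MoE phase
# is "prefill" only when every rank's local batch is prefill-only.
def decide_moe_phase(per_rank_only_prefill):
    # In the PR, all_gather_object fills this list with one bool per rank
    # (True when that rank has no decode requests); here it is passed in.
    only_prefill_batch = all(per_rank_only_prefill)
    return "prefill" if only_prefill_batch else "decode"

# Emulate 4 ranks: one rank still has decode requests, so the phase stays "decode".
assert decide_moe_phase([True, True, True, False]) == "decode"
assert decide_moe_phase([True] * 4) == "prefill"
```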
    full_length = min(
        num_tokens // batch_size,
        self.parallel_config.max_model_len - max_dec_len,
    )

    # NOTE(wanglongzhi): When full_length is too large, DeepEP's buffer size is
    # insufficient, causing the results to contain NaNs.
    # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
    if self.fd_config.parallel_config.enable_expert_parallel:
        full_length = min(full_length, 32)

    input_length = int(full_length * self.cache_config.kv_cache_ratio)
    block_num = (
        input_length + self.cache_config.block_size - 1
    ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
Can this logic guarantee that input_length equals the num_tokens we want to capture?
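As a concrete check of the question above, here is the same arithmetic wrapped in a standalone function with assumed config values (all numbers are illustrative, not FastDeploy defaults). With kv_cache_ratio below 1.0, input_length comes out smaller than num_tokens // batch_size, which is exactly the reviewer's concern:

```python
# Standalone replica of the length/block arithmetic from the diff above.
def plan_dummy_prefill(num_tokens, batch_size, max_model_len, max_dec_len,
                       kv_cache_ratio, block_size, enc_dec_block_num,
                       enable_expert_parallel=False):
    full_length = min(num_tokens // batch_size, max_model_len - max_dec_len)
    if enable_expert_parallel:
        full_length = min(full_length, 32)  # DeepEP buffer-size workaround
    input_length = int(full_length * kv_cache_ratio)
    # Ceiling division over block_size, plus the encoder/decoder reserve blocks.
    block_num = (input_length + block_size - 1) // block_size + enc_dec_block_num
    return input_length, block_num

# Assumed example values: 2048 tokens over 2 requests, kv_cache_ratio=0.75.
inp, blocks = plan_dummy_prefill(num_tokens=2048, batch_size=2,
                                 max_model_len=4096, max_dec_len=1024,
                                 kv_cache_ratio=0.75, block_size=64,
                                 enc_dec_block_num=2)
assert inp == 768      # int(1024 * 0.75): shorter than the 1024 tokens per request
assert blocks == 14    # ceil(768 / 64) + 2
```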
@@ -909,6 +967,28 @@ def initialize_forward_meta(self):
        and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
The pointer of the seq_lens_encoder tensor will change; also print the other tensors output by the get block shape kernel to check them.
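The concern above, that a captured graph keeps reading whatever buffer was recorded at capture time, can be shown with a plain-Python analogy (no CUDA involved; CapturedGraph is a made-up stand-in, not a FastDeploy or Paddle API):

```python
# A "captured graph" records a reference to its input buffer at capture time.
# In-place writes to that buffer are visible on replay; rebinding the name to
# a new buffer (a new pointer) is not, so the replay silently reads stale data.
class CapturedGraph:
    def __init__(self, input_buffer):
        self.input_buffer = input_buffer  # reference fixed at capture time

    def replay(self):
        return sum(self.input_buffer)

buf = [1, 2, 3]
graph = CapturedGraph(buf)
assert graph.replay() == 6

buf[:] = [4, 5, 6]   # in-place update: the replay sees the new values
assert graph.replay() == 15

buf = [7, 8, 9]      # rebinding creates a new object; the graph keeps the old one
assert graph.replay() == 15
```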
No description provided.