draft for enable prefill in cudagraph #3354


Draft: wants to merge 1 commit into base: develop
2 changes: 2 additions & 0 deletions fastdeploy/config.py
@@ -484,6 +484,8 @@ def __init__(
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.cudagraph_capture_prefill: bool = False
"""Now cudagraph only capture decode, whether to capture prefill """
self.full_cuda_graph: bool = True

self.max_capture_size: int = None
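As context for the two flags documented above, here is a minimal illustrative sketch (not part of this PR; the helper function is hypothetical, the attribute names follow the config fields above) of how use_cudagraph and the new cudagraph_capture_prefill are intended to interact: decode capture stays the default behavior, prefill capture is opt-in.

# Illustrative sketch only; stage_uses_cudagraph is a hypothetical helper, not code from this PR.
def stage_uses_cudagraph(cfg, stage: str) -> bool:
    """Return True if the given forward stage should run under a captured CUDA graph."""
    if not cfg.use_cudagraph:
        return False
    if stage == "decode":
        return True  # decode capture is the existing behavior
    return cfg.cudagraph_capture_prefill  # prefill capture is opt-in via the new flag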
@@ -114,6 +114,10 @@ def __call__(self, **kwargs):
new_grpah.capture_begin()
output = entry.runnable(**kwargs)
new_grpah.capture_end()
# Debug: dump the captured graph to a DOT file for inspection (hard-coded debug path)
new_grpah.print_to_dot_files(
f"/root/paddlejob/workspace/env_run/output/liujundong01/FastDeploy/debug/lazy_capture_bsz{batch_size}",
1 << 0,
)

# Store output buffer
entry.cuda_graph = new_grpah
272 changes: 272 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -131,6 +131,7 @@ def __init__(
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_capture_prefill = self.graph_opt_config.cudagraph_capture_prefill

# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -166,6 +167,15 @@ def exist_prefill(self):
else:
return 0

def exist_decode(self):
"""
Check whether a decode stage exists.
"""
if int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0:
return 1
else:
return 0

def _init_speculative_proposer(self):
"""
Init speculative proposer
@@ -541,6 +551,54 @@ def get_attr_from_request(request, attr, default_value=None):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

def _dummy_prefill_inputs_prefill(self, num_tokens: int, seq_length: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token

batch_size = 1
max_dec_len = expected_decode_len + 1
full_length = min(
num_tokens // batch_size,
self.parallel_config.max_model_len - max_dec_len,
)

# NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size is not large enough, which causes NaN results.
# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)

input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
Comment on lines +560 to +573

Collaborator: Can this logic ensure that input_length equals the num_tokens we want to capture?


for i in range(batch_size):
idx = i
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
[2] * self.model_config.eos_tokens_lens, dtype="int64"
).reshape(-1, 1)
self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = 0
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len
self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["temperature"][idx : idx + 1] = 1

self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length

self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
idx * block_num, (idx + 1) * block_num, 1
)
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
@@ -909,6 +967,28 @@ def initialize_forward_meta(self):
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
Collaborator: The pointer of the seq_lens_encoder tensor changes; also print out the other tensors produced by the get block shape kernel and check them.

)

if self.cudagraph_capture_prefill:
only_prefill_batch = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
# Gather the status of all workers
only_prefill_batch_list = []
decode_exists = self.exist_decode()
paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
only_prefill_batch = all(only_prefill_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "prefill" if only_prefill_batch else "decode"
Collaborator: Why does moe_phase.phase need to be changed?


self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and self.cudagraph_capture_prefill
and only_prefill_batch
and not (decode_exists if decode_exists is not None else self.exist_decode())
)

print(
f"in initialize_forward_meta , self.forward_meta.step_use_cudagraph:{self.forward_meta.step_use_cudagraph}"
)

# Initialize attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
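The EP gating added in this hunk boils down to: replay the prefill CUDA graph only when every expert-parallel worker agrees the current batch is prefill-only. A standalone sketch of that agreement pattern (the helper name is hypothetical; it assumes an initialized paddle.distributed process group):

import paddle.distributed as dist

def all_workers_prefill_only(local_decode_exists: bool) -> bool:
    # Each rank contributes one bool: True if its local batch has no decode requests.
    gathered = []
    dist.all_gather_object(gathered, not local_decode_exists)
    # Prefill capture/replay is only safe when every rank agrees.
    return all(gathered)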
@@ -1007,6 +1087,165 @@ def initialize_attn_backend(self) -> None:

self.attn_backends.append(attn_backend)

def _dummy_run_prefill(
Collaborator: Does the whole dummy run need to be rewritten? Wouldn't rewriting _dummy_prefill_inputs_prefill be enough?

self,
num_tokens: paddle.Tensor,
batch_size: paddle.Tensor,
expected_decode_len: int = 1,
in_capturing: bool = False,
) -> paddle.Tensor:
"""
Use dummy inputs to run before formal execution.
Args:
num_tokens: Number of prompt tokens used for the dummy prefill request
expected_decode_len: Expected number of tokens generated
in_capturing: Whether CUDA graph is in the capturing state
"""
print("####### Get in _dummy_run_prefill ######")
self._dummy_prefill_inputs_prefill(
num_tokens=num_tokens,
seq_length=num_tokens,
expected_decode_len=expected_decode_len,
)
if self.speculative_method in ["mtp"]:
self.proposer._dummy_prefill_inputs_prefill(
num_tokens=num_tokens,
seq_length=num_tokens,
expected_decode_len=expected_decode_len,
)
while True:

# 1. Initialize forward meta and attention meta data
self._prepare_inputs()

# 2. Padding inputs for cuda graph
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
print(
f"in _dummy_run_prefill ,self.forward_meta.step_use_cudagraph:{self.forward_meta.step_use_cudagraph}"
)
ids_remove_padding = self.share_inputs["ids_remove_padding"]
print(f"in _dummy_run_prefill ,ids_remove_padding:{ids_remove_padding}")
self.padding_cudagraph_inputs()

# 3. Run model
if self.enable_mm:
model_output = self.model(
self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta,
)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)

hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"],
(
self.share_inputs["output_padding_offset"] if self.speculative_decoding else None
), # required by speculative decoding
self.parallel_config.max_model_len,
)

# 4. Execute spec decode
logits = self.model.compute_logits(hidden_states)

if not self.speculative_decoding:
set_value_by_flags_and_idx(
self.share_inputs["pre_ids"],
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_idx"],
self.share_inputs["stop_flags"],
)
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
else:
self.sampler(
logits,
self.sampling_metadata,
self.parallel_config.max_model_len,
self.share_inputs,
)
sampler_output = None
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)

# 5. post process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],
stop_flags=self.share_inputs["stop_flags"],
step_idx=self.share_inputs["step_idx"],
max_dec_len=self.share_inputs["max_dec_len"],
pre_ids=self.share_inputs["pre_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
eos_token_id=self.share_inputs["eos_token_id"],
not_need_stop=self.share_inputs["not_need_stop"],
input_ids=self.share_inputs["input_ids"],
stop_nums=self.share_inputs["stop_nums"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
full_hidden_states=model_output,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep,
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
actual_draft_token_num=(
self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)

post_process(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
)

if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
else:
self.proposer.run(share_inputs=self.share_inputs)

# 7. Update 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
step_cuda(
self.share_inputs,
self.cache_config.block_size,
self.cache_config.enc_dec_block_num,
self.speculative_config,
self.cache_config.enable_prefix_caching,
)

if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break
# break

def _dummy_run(
self,
num_tokens: paddle.Tensor,
@@ -1243,6 +1482,36 @@ def capture_model(self) -> None:
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

if self.cudagraph_capture_prefill:
self.capture_model_prefill()

Comment on lines +1485 to +1487

Collaborator: Put this in gpu_worker instead; it is at the same level as the original gpu_model_runner.

def capture_model_prefill(self) -> None:
"""
Trigger CUDA Graph capture of the prefill stage for each token count in the capture list
"""
if not self.use_cudagraph:
logger.info("Skipping CUDA graph capture. Please check GraphOptimizationConfig")
return
if not self.cudagraph_capture_prefill:
logger.info("Cuda graph prefill capture is disabled.")
return
time_before_capture = time.perf_counter()
expected_decode_len = 1
# capture_sizes = self.cudagraph_capture_sizes.copy()
capture_token_nums = [256]
fixed_batch_size = 1
for token_nums in sorted(capture_token_nums, reverse=True):
self._dummy_run_prefill(
num_tokens=token_nums,
batch_size=1,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the token_nums:{token_nums}, num tokens:{expected_decode_len}")

time_after_capture = time.perf_counter()
logger.info("now get in capture_model_prefill")

@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
@@ -1316,10 +1585,13 @@ class at the server level, which is too granular for ModelRunner.
)
hidden_states = model_output
else:
# print("传递给model的seq_lens_this_time",self.forward_meta.seq_lens_this_time.shape)
# print("input_ids",self.forward_meta.input_ids.shape)
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
# print("self.share_inputs['seq_lens_this_time']",self.share_inputs["seq_lens_this_time"])
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],