@@ -39,7 +39,9 @@ GraphBase* CudaDevice::getDeviceGraphRunner(const DeviceInitParams& params,
3939}
4040
4141py::object CudaGraphRunner::normalForward (PyModelInputs& inputs) {
42- return py_forward_method_ (inputs);
42+ auto attn_pyobj = py_attn_pyobj_method_ (inputs, false );
43+ attn_pyobj.attr (" prepare" )(inputs);
44+ return py_forward_method_ (inputs, attn_pyobj);
4345}
4446
4547// column dimension
@@ -97,12 +99,8 @@ void CudaGraphRunner::prepareInputs(PyModelInputs& inputs) {
9799 optimizedCopy (inputs.attention_inputs .padding_offset ,
98100 py_model_inputs_.attention_inputs .padding_offset ,
99101 inputs.attention_inputs .padding_offset .size (0 ) * sizeof (int ));
100- graph_instances_[state_.current_real_graph_bs ].mem_hold_ .params_ptr ->fillParams (
101- inputs.attention_inputs .sequence_lengths ,
102- inputs.attention_inputs .input_lengths ,
103- inputs.attention_inputs .kv_cache_block_id_host ,
104- state_.current_batch_size ,
105- seq_size_per_block_);
102+ auto attn_pyobj = graph_instances_[state_.current_real_graph_bs ].mem_hold_ .attn_pyobj_ ;
103+ attn_pyobj.attr (" prepare_replay" )(inputs);
106104 } else {
107105 auto & py_model_inputs_ = graph_instances_[state_.current_real_graph_seq_len ].mem_hold_ .py_model_inputs_ ;
108106
@@ -343,8 +341,10 @@ void CudaGraphRunner::initCapture() {
343341 capture_mem_hold_ = CaptureMemoryHold (output, inputs, kv_cache_block_offset_, is_prefill_cuda_graph_mode_);
344342 initKernelInternalMemory ();
345343 // get real output data type
344+ auto attn_pyobj = py_attn_pyobj_method_ (capture_mem_hold_.py_model_inputs_ , true );
345+ attn_pyobj.attr (" prepare" )(capture_mem_hold_.py_model_inputs_ );
346346 RTP_LLM_LOG_INFO (" initCapture forward for output datatype start" );
347- auto py_outputs_obj = py_forward_method_ (capture_mem_hold_.py_model_inputs_ );
347+ auto py_outputs_obj = py_forward_method_ (capture_mem_hold_.py_model_inputs_ , attn_pyobj );
348348 RTP_LLM_LOG_INFO (" initCapture forward for output datatype end" );
349349 auto outputs = py_outputs_obj.cast <PyModelOutputs>();
350350 options_cuda_float_ = torch::TensorOptions ()
@@ -382,8 +382,10 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
382382 auto inputs = graph_instances_[key].mem_hold_ .py_model_inputs_ ;
383383 // WarmUp twice
384384 RTP_LLM_LOG_INFO (" WarmUp for %s %d start." , key_type, key);
385- py_forward_method_ (inputs);
386- py_forward_method_ (inputs);
385+ auto attn_pyobj = graph_instances_[key].mem_hold_ .attn_pyobj_ ;
386+ attn_pyobj.attr (" prepare" )(inputs);
387+ py_forward_method_ (inputs, attn_pyobj);
388+ py_forward_method_ (inputs, attn_pyobj);
387389 RTP_LLM_LOG_INFO (" WarmUp for %s %d successfully." , key_type, key);
388390
389391 {
@@ -399,7 +401,7 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
399401 {
400402 graph.capture_begin ();
401403 CudaGraphCaptureGuard capture_guard;
402- auto py_outputs_obj = py_forward_method_ (inputs);
404+ auto py_outputs_obj = py_forward_method_ (inputs, attn_pyobj );
403405 outputs = py_outputs_obj.cast <PyModelOutputs>();
404406 graph_instances_[key].mem_hold_ .decoder_layer_hidden_states_ .copy_ (outputs.hidden_states );
405407 graph.capture_end ();
0 commit comments