3 changes: 2 additions & 1 deletion rtp_llm/cpp/devices/cuda_impl/CudaFlashInfer.cc
@@ -142,7 +142,8 @@ void FlashInferAttnParams::fillParams(torch::Tensor sequence_lengths,
torch::Tensor input_lengths,
torch::Tensor kv_cache_block_id_host,
int batch_size,
int seq_size_per_block) {
int seq_size_per_block,
torch::Tensor prefix_lengths) {
fillFlashInfer(nullptr,
torchTensor2Buffer(sequence_lengths),
torchTensor2Buffer(input_lengths),
3 changes: 2 additions & 1 deletion rtp_llm/cpp/devices/cuda_impl/CudaFlashInfer.h
@@ -109,7 +109,8 @@ struct FlashInferAttnParams: ParamsBase {
torch::Tensor input_lengths,
torch::Tensor kv_cache_block_id_host,
int batch_size,
int seq_size_per_block) override;
int seq_size_per_block,
torch::Tensor prefix_lengths = torch::Tensor()) override;
void fillFlashInfer(const BufferPtr& prefix_lengths_host,
const BufferPtr& sequence_lengths_host,
const BufferPtr& input_lengths_host,
2 changes: 2 additions & 0 deletions rtp_llm/cpp/devices/cuda_impl/CudaGraphDecode.cc
@@ -49,6 +49,8 @@ void CudaGraphRunner::captureDecode() {
inputs.attention_inputs.context_total_kv_length = bs * (max_input_len + max_prefix_len);

graph_instances_[bs].mem_hold_ = createCaptureMemoryHold(inputs, bs * num_tokens_per_bs_);
graph_instances_[bs].mem_hold_.attn_pyobj_ =
py_attn_pyobj_method_(graph_instances_[bs].mem_hold_.py_model_inputs_, true);
captureDecodeOneBatchSize(bs);
replayAndSyncCheck(bs, "batch size");
RTP_LLM_LOG_INFO("capture success for batch size: %d", bs);
2 changes: 2 additions & 0 deletions rtp_llm/cpp/devices/cuda_impl/CudaGraphPrefill.cc
@@ -27,6 +27,8 @@ void CudaGraphRunner::capturePrefill() {
inputs.bert_embedding_inputs.combo_tokens_type_ids.slice(0, 0, seq_len);
}
graph_instances_[seq_len].mem_hold_ = createCaptureMemoryHold(inputs, max_bs_ * num_tokens_per_bs_);
graph_instances_[seq_len].mem_hold_.attn_pyobj_ =
py_attn_pyobj_method_(graph_instances_[seq_len].mem_hold_.py_model_inputs_, true);
graph_instances_[seq_len].mem_hold_.decoder_layer_hidden_states_ =
graph_instances_[seq_len].mem_hold_.decoder_layer_hidden_states_.slice(0, 0, seq_len);
capturePrefillOneSeqLen(seq_len);
35 changes: 14 additions & 21 deletions rtp_llm/cpp/devices/cuda_impl/CudaGraphRunner.cc
@@ -50,7 +50,8 @@ void optimizedCopyAsync(const torch::Tensor& src, torch::Tensor& dst, size_t siz
}

py::object CudaGraphRunner::normalForward(PyModelInputs& inputs) {
return py_forward_method_(inputs);
auto attn_pyobj = py_attn_pyobj_method_(inputs, false);
return py_forward_method_(inputs, attn_pyobj);
}

// column dimension
@@ -126,20 +127,15 @@ void CudaGraphRunner::prepareInputs(PyModelInputs& inputs) {

copySmallerIntoLarger(inputs.attention_inputs.kv_cache_block_id_device,
py_model_inputs_.attention_inputs.kv_cache_block_id_device);

optimizedCopyAsync(inputs.attention_inputs.sequence_lengths_plus_1_d,
py_model_inputs_.attention_inputs.sequence_lengths_plus_1_d,
state_.current_batch_size * sizeof(int));
optimizedCopyAsync(inputs.attention_inputs.decode_cu_seqlens_d,
py_model_inputs_.attention_inputs.decode_cu_seqlens_d,
(state_.current_batch_size + 1) * sizeof(int));
if (graph_instances_[state_.current_real_graph_bs].mem_hold_.params_ptr) {
graph_instances_[state_.current_real_graph_bs].mem_hold_.params_ptr->fillParams(
inputs.attention_inputs.sequence_lengths,
inputs.attention_inputs.input_lengths,
inputs.attention_inputs.kv_cache_block_id_host,
state_.current_batch_size,
seq_size_per_block_);
}
auto attn_pyobj = graph_instances_[state_.current_real_graph_bs].mem_hold_.attn_pyobj_;
attn_pyobj.attr("prepare")(inputs.attention_inputs);
} else {
auto& py_model_inputs_ = graph_instances_[state_.current_real_graph_seq_len].mem_hold_.py_model_inputs_;
// clear kv_cache_block_id_device, otherwise it will cause cache block pollution
@@ -410,9 +406,11 @@ void CudaGraphRunner::initCapture() {
torch::Tensor output;
capture_mem_hold_ = CaptureMemoryHold(output, inputs, is_prefill_cuda_graph_mode_);
initKernelInternalMemory();
// do warm up here to get a stable environment, otherwise it will cause kernel errors.
// get real output data type
auto attn_pyobj = py_attn_pyobj_method_(capture_mem_hold_.py_model_inputs_, true);
attn_pyobj.attr("prepare")(capture_mem_hold_.py_model_inputs_.attention_inputs);
RTP_LLM_LOG_INFO("initCapture forward for output datatype start");
py_forward_method_(capture_mem_hold_.py_model_inputs_);
py_forward_method_(capture_mem_hold_.py_model_inputs_, attn_pyobj);
RTP_LLM_LOG_INFO("initCapture forward for output datatype end");
output = torch::zeros({max_num_token_, hidden_size_}, options_cuda_float_);
capture_mem_hold_.setHiddenStates(output);
@@ -443,8 +441,10 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
auto inputs = graph_instances_[key].mem_hold_.py_model_inputs_;
// WarmUp twice
RTP_LLM_LOG_INFO("WarmUp for %s %d start.", key_type, key);
py_forward_method_(inputs);
py_forward_method_(inputs);
auto attn_pyobj = graph_instances_[key].mem_hold_.attn_pyobj_;
attn_pyobj.attr("prepare")(inputs.attention_inputs);
py_forward_method_(inputs, attn_pyobj);
py_forward_method_(inputs, attn_pyobj);
RTP_LLM_LOG_INFO("WarmUp for %s %d successfully.", key_type, key);

{
@@ -469,18 +469,11 @@
{
graph.capture_begin();
CudaGraphCaptureGuard capture_guard;
auto py_outputs_obj = py_forward_method_(inputs);
auto py_outputs_obj = py_forward_method_(inputs, attn_pyobj);
outputs = py_outputs_obj.cast<PyModelOutputs>();
graph_instances_[key].mem_hold_.decoder_layer_hidden_states_.copy_(outputs.hidden_states);
graph.capture_end();
}
RTP_LLM_LOG_INFO("Capture for %s %d end.", key_type, key);
if (outputs.params_ptr && outputs.params_ptr->check_recycle()) {
graph_instances_[key].mem_hold_.params_ptr =
ParamsBasePtr(outputs.params_ptr.get(), [&](ParamsBase* ptr) {});
} else {
graph_instances_[key].mem_hold_.params_ptr = outputs.params_ptr;
}

if (enable_cuda_graph_debug_mode_) {
graph.debug_dump(output_dot_filename.c_str());
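Note: taken together, the CudaGraphRunner.cc hunks replace the old `fillParams()`/`params_ptr` bookkeeping with a Python attention object cached per graph instance. A condensed sketch of the assumed flow (member names from the diff; everything else elided, and the meaning of the boolean passed to `py_attn_pyobj_method_` is an assumption):

```cpp
// Sketch only -- condenses the capture/replay flow shown above, assuming
// py_attn_pyobj_method_ is bound to the model's prepare_fmha_impl and its
// second argument marks CUDA-graph (capture/replay) mode.

// Capture: build and cache the attention object for this graph instance.
auto& hold       = graph_instances_[key].mem_hold_;
hold.attn_pyobj_ = py_attn_pyobj_method_(hold.py_model_inputs_, /*is_cuda_graph=*/true);
hold.attn_pyobj_.attr("prepare")(hold.py_model_inputs_.attention_inputs);
py_forward_method_(hold.py_model_inputs_, hold.attn_pyobj_);   // warm-up
// ... graph.capture_begin(); forward; graph.capture_end(); ...

// Replay: refresh the cached object from the live request, then replay.
auto attn_pyobj = graph_instances_[state_.current_real_graph_bs].mem_hold_.attn_pyobj_;
attn_pyobj.attr("prepare")(inputs.attention_inputs);
// graph.replay();

// Normal (non-graph) forward builds a fresh object instead of caching one.
auto fresh = py_attn_pyobj_method_(inputs, /*is_cuda_graph=*/false);
py_forward_method_(inputs, fresh);
```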
12 changes: 6 additions & 6 deletions rtp_llm/cpp/devices/cuda_impl/CudaGraphRunner.h
@@ -40,11 +40,11 @@ class CudaGraphRunner: public GraphBase {
} else {
max_bs_ = params.concurrency_config.concurrency_limit;
}
py_forward_method_ = py_instance_.attr("forward");
py_fill_params_method_ = py_instance_.attr("fill_params");
options_cuda_int32_ = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false);
options_cpu_int32_ = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU).requires_grad(false);
options_cuda_float_ = torch::TensorOptions().dtype(model_data_type).device(torch::kCUDA).requires_grad(false);
py_attn_pyobj_method_ = py_instance_.attr("prepare_fmha_impl");
py_forward_method_ = py_instance_.attr("forward");
options_cuda_int32_ = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false);
options_cpu_int32_ = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU).requires_grad(false);
options_cuda_float_ = torch::TensorOptions().dtype(model_data_type).device(torch::kCUDA).requires_grad(false);
RTP_LLM_LOG_INFO("Initialize CudaGraphRunner with parameters below: \n \
enable_cuda_graph_: %d, max_bs_: %d, enable_cuda_graph_debug_mode_: %d, max_seq_len_: %d, seq_size_per_block_: %d, \
hidden_size_: %d, num_tokens_per_bs_: %d, is_prefill_cuda_graph_mode_: %d",
@@ -103,7 +103,7 @@ class CudaGraphRunner: public GraphBase {
void initCaptureBertEmbeddingInputs(PyModelInputs& inputs, int max_bs, int max_num_token);
void initCaptureAttentionInputsPost();
py::object py_forward_method_;
py::object py_fill_params_method_;
py::object py_attn_pyobj_method_;
bool enable_cuda_graph_{false};
bool is_prefill_cuda_graph_mode_{false};
at::cuda::CUDAStream capture_stream_;
3 changes: 1 addition & 2 deletions rtp_llm/cpp/devices/cuda_impl/CudaGraphUtils.h
@@ -53,8 +53,7 @@ class CaptureMemoryHold {
}

public:
// for attention params
rtp_llm::ParamsBasePtr params_ptr{nullptr};
py::object attn_pyobj_{py::none()};
// for output
at::Tensor decoder_layer_hidden_states_;
// for input
10 changes: 6 additions & 4 deletions rtp_llm/cpp/models/PyWrappedModel.cc
@@ -302,10 +302,12 @@ GptModelOutputs PyWrappedModel::forward(const GptModelInputs& inputs) {
hidden_states = device_->clone({*torchTensor2Buffer(py_model_outputs.hidden_states)});
} else {
DevicePerfWrapper wrapper(device_, "normal forward");
auto py_model_forward = py_model_.attr("forward");
auto outputs = py_model_forward(py_model_inputs);
py_model_outputs = outputs.cast<PyModelOutputs>();
hidden_states = device_->clone({*torchTensor2Buffer(py_model_outputs.hidden_states)});
auto attn_pyobj = py_model_.attr("prepare_fmha_impl")(py_model_inputs, false);
// attn_pyobj.attr("prepare")(py_model_inputs.attention_inputs);
auto py_model_forward = py_model_.attr("forward");
auto outputs = py_model_forward(py_model_inputs, attn_pyobj);
py_model_outputs = outputs.cast<PyModelOutputs>();
hidden_states = device_->clone({*torchTensor2Buffer(py_model_outputs.hidden_states)});
}

RTP_LLM_LOG_DEBUG("Python object instance forward method called successfully.");
3 changes: 3 additions & 0 deletions rtp_llm/models/deepseek_v2.py
@@ -515,6 +515,9 @@ def _create_config(cls, ckpt_path: str):
DeepSeekV2._from_hf(config, ckpt_path)
return config

def support_cuda_graph(self) -> bool:
return True

def _create_python_model(self) -> Optional[GptModelBase]:
model_config = self.model_config
parallelism_config = self.parallelism_config
14 changes: 1 addition & 13 deletions rtp_llm/models_py/bindings/OpDefs.cc
@@ -45,19 +45,6 @@ void registerPyOpDefs(pybind11::module& m) {
pybind11::arg("seq_size_per_block"),
"Fill parameters for CUDA graph execution");

pybind11::class_<MlaParams, std::shared_ptr<MlaParams>, rtp_llm::ParamsBase>(m, "MlaParams")
.def(pybind11::init<>())
.def_readonly("batch_indice", &MlaParams::batch_indice)
.def_readonly("positions", &MlaParams::positions)
.def_readonly("paged_kv_last_page_len", &MlaParams::paged_kv_last_page_len)
.def_readonly("kvlen", &MlaParams::kvlen)
.def_readonly("page_indice", &MlaParams::page_indice)
.def_readonly("reuse_cache_page_indice", &MlaParams::reuse_cache_page_indice)
.def_readonly("decode_page_indptr", &MlaParams::decode_page_indptr)
.def_readonly("prefill_page_indptr", &MlaParams::prefill_page_indptr)
.def_readonly("qo_indptr", &MlaParams::qo_indptr)
.def_readonly("batch_reuse_info_vec", &MlaParams::batch_reuse_info_vec);

pybind11::class_<PyPrefillCudaGaphCopyParams>(m, "PyPrefillCudaGaphCopyParams")
.def(pybind11::init<>())
.def_readonly("cuda_graph_prefill_batch_size", &PyPrefillCudaGaphCopyParams::cuda_graph_prefill_batch_size)
@@ -67,6 +54,7 @@ void registerPyOpDefs(pybind11::module& m) {
pybind11::class_<PyAttentionInputs>(m, "PyAttentionInputs")
.def(pybind11::init<>())
.def_readwrite("is_prefill", &PyAttentionInputs::is_prefill)
.def_readwrite("is_cuda_graph", &PyAttentionInputs::is_cuda_graph)
.def_readwrite("prefix_lengths", &PyAttentionInputs::prefix_lengths)
.def_readwrite("sequence_lengths", &PyAttentionInputs::sequence_lengths)
.def_readwrite("input_lengths", &PyAttentionInputs::input_lengths)
19 changes: 3 additions & 16 deletions rtp_llm/models_py/bindings/OpDefs.h
@@ -8,22 +8,6 @@
#include "rtp_llm/models_py/bindings/ParamsBase.h"
#include "rtp_llm/cpp/utils/Logger.h"
namespace torch_ext {
struct MlaParams: public rtp_llm::ParamsBase {
torch::Tensor batch_indice;
torch::Tensor positions;
torch::Tensor paged_kv_last_page_len;
torch::Tensor kvlen;
torch::Tensor page_indice;
torch::Tensor reuse_cache_page_indice;
torch::Tensor decode_page_indptr;
torch::Tensor prefill_page_indptr;
torch::Tensor qo_indptr;
torch::Tensor batch_reuse_info_vec;

// Hidden field to keep FlashInferMlaAttnParams object alive
// This ensures the underlying buffers (buf_d, buf_h) are not deallocated
std::shared_ptr<void> _params_holder;
};

struct KVCache {
torch::Tensor kv_cache_base;
@@ -102,6 +86,9 @@ struct PyAttentionInputs {
torch::Tensor sequence_lengths_plus_1_d;
torch::Tensor input_lengths_d;
torch::Tensor decode_cu_seqlens_d;

// CUDA Graph mode flag
bool is_cuda_graph = false;
};

struct BertEmbeddingInputs {
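The new `is_cuda_graph` field gives attention implementations a way to tell graph capture/replay apart from normal execution. A hedged example of how a consumer might branch on it; the fixed-shape sizing policy and the member names below are assumptions, not taken from this diff:

```cpp
// Hypothetical consumer of PyAttentionInputs::is_cuda_graph. When inputs come
// from the CUDA graph runner, buffers are sized to the captured maximum so
// replay always sees the same shapes and pointers.
void FmhaWorkspace::prepare(const torch_ext::PyAttentionInputs& attn) {
    const int64_t rows = attn.is_cuda_graph
                             ? max_capture_batch_size_      // fixed shape for graph replay (assumed member)
                             : attn.input_lengths.size(0);  // size to the live batch otherwise
    workspace_ = torch::zeros({rows, workspace_cols_},
                              torch::dtype(torch::kInt32).device(torch::kCUDA));
}
```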
3 changes: 2 additions & 1 deletion rtp_llm/models_py/bindings/ParamsBase.h
@@ -12,7 +12,8 @@ class ParamsBase {
torch::Tensor input_lengths,
torch::Tensor kv_cache_block_id_host,
int batch_size,
int seq_size_per_block) {};
int seq_size_per_block,
torch::Tensor prefix_lengths = torch::Tensor()) {};
// check whether the params can be recycled automatically.
virtual bool check_recycle() {
return false;
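Both the `ParamsBase::fillParams` virtual and the `FlashInferAttnParams` override (CudaFlashInfer.h above) repeat the `prefix_lengths = torch::Tensor()` default. The repetition is deliberate: C++ resolves default arguments on virtual functions through the static type of the call expression, as this minimal, non-project sketch shows:

```cpp
// Minimal illustration (not project code): default arguments on virtual
// functions are bound statically, so an override must restate the default
// for calls made through a derived-typed pointer to keep compiling.
#include <torch/torch.h>

struct Base {
    virtual ~Base() = default;
    virtual void fillParams(torch::Tensor prefix_lengths = torch::Tensor()) {}
};

struct Derived : Base {
    // Drop "= torch::Tensor()" here and derived->fillParams() below stops
    // compiling, while base->fillParams() still works with Base's default.
    void fillParams(torch::Tensor prefix_lengths = torch::Tensor()) override {}
};

int main() {
    Derived  d;
    Base*    base    = &d;
    Derived* derived = &d;
    base->fillParams();     // default supplied by Base's declaration
    derived->fillParams();  // default supplied by Derived's declaration
    return 0;
}
```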
35 changes: 35 additions & 0 deletions rtp_llm/models_py/bindings/cuda/DebugKernelOp.cc
@@ -0,0 +1,35 @@
#include "rtp_llm/models_py/bindings/cuda/DebugKernelOp.h"
#include "rtp_llm/cpp/core/Dispatch.h"
#include "rtp_llm/cpp/utils/AssertUtils.h"
#include "rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h"

namespace rtp_llm {

void debugKernel(const torch::Tensor& data,
int64_t start_row,
int64_t start_col,
int64_t m,
int64_t n,
int64_t row_len,
int64_t info_id) {
// Validate input tensor
RTP_LLM_CHECK_WITH_INFO(data.is_cuda(), "Input tensor must be on CUDA device");
RTP_LLM_CHECK_WITH_INFO(data.is_contiguous(), "Input tensor must be contiguous");

// Get CUDA stream
auto stream = c10::cuda::getCurrentCUDAStream(data.get_device());

// Dispatch based on data type
DISPATCH_CUDA_FUNCTION_DATA_TYPE(torchDTypeToDataType(data.dtype()),
invoke_debug_kernel2,
data.data_ptr(),
static_cast<int>(start_row),
static_cast<int>(start_col),
static_cast<int>(m),
static_cast<int>(n),
static_cast<int>(row_len),
static_cast<int>(info_id),
stream);
}

} // namespace rtp_llm
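`DISPATCH_CUDA_FUNCTION_DATA_TYPE` is not shown in this diff; presumably it maps the runtime dtype onto a typed instantiation of `invoke_debug_kernel2`. A generic sketch of that dispatch pattern, with the launcher signature assumed from the argument list above (not the project's actual macro):

```cpp
// Generic sketch of dtype dispatch: switch on the runtime dtype, cast the raw
// pointer, and call the typed launcher with the same trailing arguments that
// DebugKernelOp.cc passes. invoke_debug_kernel2's real signature is assumed.
void dispatchDebugKernel(const torch::Tensor& data, int start_row, int start_col,
                         int m, int n, int row_len, int info_id, cudaStream_t stream) {
    switch (data.scalar_type()) {
        case at::kHalf:
            invoke_debug_kernel2(reinterpret_cast<half*>(data.data_ptr()),
                                 start_row, start_col, m, n, row_len, info_id, stream);
            break;
        case at::kBFloat16:
            invoke_debug_kernel2(reinterpret_cast<__nv_bfloat16*>(data.data_ptr()),
                                 start_row, start_col, m, n, row_len, info_id, stream);
            break;
        case at::kFloat:
            invoke_debug_kernel2(reinterpret_cast<float*>(data.data_ptr()),
                                 start_row, start_col, m, n, row_len, info_id, stream);
            break;
        default:
            TORCH_CHECK(false, "debugKernel: unsupported dtype");
    }
}
```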
25 changes: 25 additions & 0 deletions rtp_llm/models_py/bindings/cuda/DebugKernelOp.h
@@ -0,0 +1,25 @@
#pragma once

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include "rtp_llm/cpp/kernels/unfused_attention_kernels.h"

namespace rtp_llm {

/// @brief Debug kernel to print 2D data blocks
/// @param data Input tensor to debug
/// @param start_row Starting row index
/// @param start_col Starting column index
/// @param m Number of rows to print
/// @param n Number of columns to print
/// @param row_len Length of each row (stride)
/// @param info_id Debug identifier
void debugKernel(const torch::Tensor& data,
int64_t start_row,
int64_t start_col,
int64_t m,
int64_t n,
int64_t row_len,
int64_t info_id);

} // namespace rtp_llm
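A short C++ usage sketch based only on the declaration above; the tensor, block size, and info id are illustrative, and how (or whether) the helper is exposed to Python is not part of this diff:

```cpp
#include <torch/torch.h>
#include "rtp_llm/models_py/bindings/cuda/DebugKernelOp.h"

void dumpScoresBlock(const torch::Tensor& scores) {
    // debugKernel() requires a contiguous CUDA tensor; print a 4x8 block
    // starting at (0, 0), with the tensor's row stride, tagged info_id 42.
    rtp_llm::debugKernel(scores.contiguous(),
                         /*start_row=*/0,
                         /*start_col=*/0,
                         /*m=*/4,
                         /*n=*/8,
                         /*row_len=*/scores.size(1),
                         /*info_id=*/42);
}
```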