Skip to content

Commit a4c4319

Browse files
committed
feat - refactor fmha python in cudagraph mode
1 parent 4303504 commit a4c4319

File tree

21 files changed

+254
-48
lines changed

21 files changed

+254
-48
lines changed

rtp_llm/cpp/devices/cuda_impl/CudaGraphDecode.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ void CudaGraphRunner::captureDecode() {
4444
prepareCaptureInputs(inputs, bs, bs * num_tokens_per_bs_);
4545

4646
graph_instances_[bs].mem_hold_ = createCaptureMemoryHold(inputs, bs * num_tokens_per_bs_);
47+
graph_instances_[bs].mem_hold_.attn_pyobj_ =
48+
py_attn_pyobj_method_(graph_instances_[bs].mem_hold_.py_model_inputs_, true);
4749
captureDecodeOneBatchSize(bs);
4850
replayAndSyncCheck(bs, "batch size");
4951
RTP_LLM_LOG_INFO("capture success for batch size: %d", bs);

rtp_llm/cpp/devices/cuda_impl/CudaGraphRunner.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ GraphBase* CudaDevice::getDeviceGraphRunner(const DeviceInitParams& params,
3939
}
4040

4141
py::object CudaGraphRunner::normalForward(PyModelInputs& inputs) {
42-
return py_forward_method_(inputs);
42+
auto attn_pyobj = py_attn_pyobj_method_(inputs, false);
43+
attn_pyobj.attr("prepare")(inputs);
44+
return py_forward_method_(inputs, attn_pyobj);
4345
}
4446

4547
// column dimension
@@ -97,12 +99,8 @@ void CudaGraphRunner::prepareInputs(PyModelInputs& inputs) {
9799
optimizedCopy(inputs.attention_inputs.padding_offset,
98100
py_model_inputs_.attention_inputs.padding_offset,
99101
inputs.attention_inputs.padding_offset.size(0) * sizeof(int));
100-
graph_instances_[state_.current_real_graph_bs].mem_hold_.params_ptr->fillParams(
101-
inputs.attention_inputs.sequence_lengths,
102-
inputs.attention_inputs.input_lengths,
103-
inputs.attention_inputs.kv_cache_block_id_host,
104-
state_.current_batch_size,
105-
seq_size_per_block_);
102+
auto attn_pyobj = graph_instances_[state_.current_real_graph_bs].mem_hold_.attn_pyobj_;
103+
attn_pyobj.attr("prepare_replay")(inputs);
106104
} else {
107105
auto& py_model_inputs_ = graph_instances_[state_.current_real_graph_seq_len].mem_hold_.py_model_inputs_;
108106

@@ -343,8 +341,10 @@ void CudaGraphRunner::initCapture() {
343341
capture_mem_hold_ = CaptureMemoryHold(output, inputs, kv_cache_block_offset_, is_prefill_cuda_graph_mode_);
344342
initKernelInternalMemory();
345343
// get real output data type
344+
auto attn_pyobj = py_attn_pyobj_method_(capture_mem_hold_.py_model_inputs_, true);
345+
attn_pyobj.attr("prepare")(capture_mem_hold_.py_model_inputs_);
346346
RTP_LLM_LOG_INFO("initCapture forward for output datatype start");
347-
auto py_outputs_obj = py_forward_method_(capture_mem_hold_.py_model_inputs_);
347+
auto py_outputs_obj = py_forward_method_(capture_mem_hold_.py_model_inputs_, attn_pyobj);
348348
RTP_LLM_LOG_INFO("initCapture forward for output datatype end");
349349
auto outputs = py_outputs_obj.cast<PyModelOutputs>();
350350
options_cuda_float_ = torch::TensorOptions()
@@ -382,8 +382,10 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
382382
auto inputs = graph_instances_[key].mem_hold_.py_model_inputs_;
383383
// WarmUp twice
384384
RTP_LLM_LOG_INFO("WarmUp for %s %d start.", key_type, key);
385-
py_forward_method_(inputs);
386-
py_forward_method_(inputs);
385+
auto attn_pyobj = graph_instances_[key].mem_hold_.attn_pyobj_;
386+
attn_pyobj.attr("prepare")(inputs);
387+
py_forward_method_(inputs, attn_pyobj);
388+
py_forward_method_(inputs, attn_pyobj);
387389
RTP_LLM_LOG_INFO("WarmUp for %s %d successfully.", key_type, key);
388390

389391
{
@@ -399,7 +401,7 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
399401
{
400402
graph.capture_begin();
401403
CudaGraphCaptureGuard capture_guard;
402-
auto py_outputs_obj = py_forward_method_(inputs);
404+
auto py_outputs_obj = py_forward_method_(inputs, attn_pyobj);
403405
outputs = py_outputs_obj.cast<PyModelOutputs>();
404406
graph_instances_[key].mem_hold_.decoder_layer_hidden_states_.copy_(outputs.hidden_states);
405407
graph.capture_end();

rtp_llm/cpp/devices/cuda_impl/CudaGraphRunner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class CudaGraphRunner: public GraphBase {
3838
} else {
3939
max_bs_ = params.concurrency_config.concurrency_limit;
4040
}
41+
py_attn_pyobj_method_ = py_instance_.attr("prepare_fmha_impl");
4142
py_forward_method_ = py_instance_.attr("forward");
4243
py_fill_params_method_ = py_instance_.attr("fill_params");
4344
options_cuda_int32_ = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false);
@@ -99,6 +100,7 @@ class CudaGraphRunner: public GraphBase {
99100
void initCaptureBertEmbeddingInputs(PyModelInputs& inputs, int max_bs, int max_num_token);
100101
void initCaptureAttentionInputsPost();
101102
py::object py_forward_method_;
103+
py::object py_attn_pyobj_method_;
102104
py::object py_fill_params_method_;
103105
bool enable_cuda_graph_{false};
104106
bool is_prefill_cuda_graph_mode_{false};

rtp_llm/cpp/devices/cuda_impl/CudaGraphUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class CaptureMemoryHold {
4343
at::Tensor decoder_layer_hidden_states_;
4444
// for input
4545
PyModelInputs py_model_inputs_;
46+
py::object attn_pyobj_;
4647
};
4748

4849
class GraphInstance {

rtp_llm/cpp/models/PyWrappedModel.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,12 @@ GptModelOutputs PyWrappedModel::forward(const GptModelInputs& inputs) {
256256
hidden_states = torchTensor2Buffer(py_model_outputs.hidden_states);
257257
} else {
258258
DevicePerfWrapper wrapper(device_, "normal forward");
259-
auto py_model_forward = py_model_.attr("forward");
260-
auto outputs = py_model_forward(py_model_inputs);
261-
py_model_outputs = outputs.cast<PyModelOutputs>();
262-
hidden_states = device_->clone({*torchTensor2Buffer(py_model_outputs.hidden_states)});
259+
auto attn_pyobj = py_model_.attr("prepare_fmha_impl")(py_model_inputs, false);
260+
attn_pyobj.attr("prepare")(py_model_inputs);
261+
auto py_model_forward = py_model_.attr("forward");
262+
auto outputs = py_model_forward(py_model_inputs, attn_pyobj);
263+
py_model_outputs = outputs.cast<PyModelOutputs>();
264+
hidden_states = device_->clone({*torchTensor2Buffer(py_model_outputs.hidden_states)});
263265
}
264266

265267
RTP_LLM_LOG_DEBUG("Python object instance forward method called successfully.");
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "rtp_llm/models_py/bindings/cuda/DebugKernelOp.h"
2+
#include "rtp_llm/cpp/core/Dispatch.h"
3+
#include "rtp_llm/cpp/utils/AssertUtils.h"
4+
#include "rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h"
5+
6+
namespace rtp_llm {
7+
8+
/// @brief Launch the device-side debug kernel that prints an m x n block of a
///        2-D tensor, starting at (start_row, start_col).
/// @param data      CUDA tensor to inspect; must be contiguous.
/// @param start_row first row of the printed block
/// @param start_col first column of the printed block
/// @param m         number of rows to print
/// @param n         number of columns to print
/// @param row_len   row stride of `data`; 0 means "use data.sizes()[1]"
///                  (as documented on the python binding)
/// @param info_id   identifier echoed with the printout to tag it in logs
void DebugKernelOp::forward(const torch::Tensor& data,
                            int64_t start_row,
                            int64_t start_col,
                            int64_t m,
                            int64_t n,
                            int64_t row_len,
                            int64_t info_id) {
    // Validate input tensor: the kernel reads device memory directly, so the
    // tensor must live on the GPU and be densely packed.
    RTP_LLM_CHECK_WITH_INFO(data.is_cuda(), "Input tensor must be on CUDA device");
    RTP_LLM_CHECK_WITH_INFO(data.is_contiguous(), "Input tensor must be contiguous");

    // The binding documents row_len == 0 as "use data.sizes()[1]"; implement
    // that fallback here so the kernel never receives a zero row stride.
    if (row_len == 0 && data.dim() > 1) {
        row_len = data.size(1);
    }

    // Run on the current stream of the tensor's device so the print is
    // ordered with respect to the surrounding kernels.
    auto stream = c10::cuda::getCurrentCUDAStream(data.get_device());

    // Dispatch to the typed kernel instantiation matching the tensor's dtype.
    DISPATCH_CUDA_FUNCTION_DATA_TYPE(torchDTypeToDataType(data.dtype()),
                                     invoke_debug_kernel2,
                                     data.data_ptr(),
                                     static_cast<int>(start_row),
                                     static_cast<int>(start_col),
                                     static_cast<int>(m),
                                     static_cast<int>(n),
                                     static_cast<int>(row_len),
                                     static_cast<int>(info_id),
                                     stream);
}
34+
35+
/// Register the DebugKernelOp python binding on module `m`.
/// Defaults mirror the common usage: print a 30x10 block from the origin.
void registerDebugKernelOp(const py::module& m) {
    auto cls = pybind11::class_<DebugKernelOp>(m, "DebugKernelOp");
    cls.def(pybind11::init<>());
    cls.def("forward",
            &DebugKernelOp::forward,
            py::arg("data"),
            py::arg("start_row") = 0,
            py::arg("start_col") = 0,
            py::arg("m") = 30,
            py::arg("n") = 10,
            py::arg("row_len") = 0,  // Will use data.sizes()[1] if 0
            py::arg("info_id") = 1);
}
48+
49+
} // namespace rtp_llm
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
#include <c10/cuda/CUDAStream.h>
5+
#include "rtp_llm/cpp/kernels/unfused_attention_kernels.h"
6+
7+
namespace rtp_llm {
8+
9+
class DebugKernelOp {
10+
public:
11+
DebugKernelOp() = default;
12+
13+
/// @brief Debug kernel to print 2D data blocks
14+
/// @param data Input tensor to debug
15+
/// @param start_row Starting row index
16+
/// @param start_col Starting column index
17+
/// @param m Number of rows to print
18+
/// @param n Number of columns to print
19+
/// @param row_len Length of each row (stride)
20+
/// @param info_id Debug identifier
21+
void forward(const torch::Tensor& data,
22+
int64_t start_row,
23+
int64_t start_col,
24+
int64_t m,
25+
int64_t n,
26+
int64_t row_len,
27+
int64_t info_id);
28+
};
29+
30+
void registerDebugKernelOp(const py::module& m);
31+
32+
} // namespace rtp_llm

rtp_llm/models_py/bindings/cuda/FusedRopeKVCacheOp.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,19 @@ torch::Tensor FusedRopeKVCacheDecodeOp::forward(const torch::Tensor&
225225
}
226226

227227
void registerFusedRopeKVCacheOp(const py::module& m) {
228-
pybind11::class_<KVBlockArray>(m, "KVBlockArray").def(pybind11::init<>());
229-
pybind11::class_<TRTAttn, std::shared_ptr<TRTAttn>, rtp_llm::ParamsBase>(m, "TRTAttn").def(pybind11::init<>());
228+
pybind11::class_<KVBlockArray>(m, "KVBlockArray")
229+
.def(pybind11::init<>())
230+
.def(
231+
"__cpp_ptr__",
232+
[](KVBlockArray& self) { return reinterpret_cast<uintptr_t>(&self); },
233+
"Get C++ object pointer address");
234+
pybind11::class_<TRTAttn, std::shared_ptr<TRTAttn>, rtp_llm::ParamsBase>(m, "TRTAttn")
235+
.def(pybind11::init<>())
236+
.def_readwrite("kv_cache_offset", &TRTAttn::kv_cache_offset)
237+
.def(
238+
"__cpp_ptr__",
239+
[](TRTAttn& self) { return reinterpret_cast<uintptr_t>(&self); },
240+
"Get C++ object pointer address");
230241
pybind11::class_<FusedRopeKVCachePrefillOp>(m, "FusedRopeKVCachePrefillOp")
231242
.def(pybind11::init<GptInitParameter>(), py::arg("gpt_init_parameter"))
232243
.def("prepare", &FusedRopeKVCachePrefillOp::prepare, py::arg("attn_inputs"))

rtp_llm/models_py/bindings/cuda/RegisterCudaOps.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "rtp_llm/cpp/cuda/cutlass/cutlass_kernels/fp8_group_gemm/fp8_group_gemm.h"
55
#include "rtp_llm/cpp/kernels/scaled_fp8_quant.h"
66
#include "rtp_llm/cpp/kernels/moe/ep_utils.h"
7+
#include "rtp_llm/models_py/bindings/cuda/DebugKernelOp.h"
78

89
namespace rtp_llm {
910

@@ -94,6 +95,7 @@ void registerPyModuleOps(py::module& rtp_ops_m) {
9495

9596
registerBaseCudaBindings(rtp_ops_m);
9697
registerAttnOpBindings(rtp_ops_m);
98+
registerDebugKernelOp(rtp_ops_m);
9799
}
98100

99101
} // namespace rtp_llm

rtp_llm/models_py/bindings/cuda/XQAAttnOp.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ XQAAttnOp::forward(const torch::Tensor& input, std::optional<torch_ext::KVCache>
9090

9191
void registerXQAAttnOp(const py::module& m) {
9292
pybind11::class_<XQAParams, std::shared_ptr<XQAParams>, rtp_llm::ParamsBase>(m, "XQAParams")
93-
.def(pybind11::init<>());
93+
.def(pybind11::init<>())
94+
.def(
95+
"__cpp_ptr__",
96+
[](XQAParams& self) { return reinterpret_cast<uintptr_t>(&self); },
97+
"Get C++ object pointer address")
98+
.def_readwrite("kv_cache_offset", &XQAParams::kv_cache_offset);
9499
pybind11::class_<XQAAttnOp>(m, "XQAAttnOp")
95100
.def(pybind11::init<GptInitParameter>(), py::arg("gpt_init_parameter"))
96101
.def("support", &XQAAttnOp::support, py::arg("attn_inputs").noconvert())

0 commit comments

Comments (0)