Commit 5154556
[TRTLLM-8803][feat] Add rope and uk-bgemm overlap for mla generation (NVIDIA#8495)
Signed-off-by: yunruis <[email protected]>
1 parent b7798bf commit 5154556
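
This change removes the `invokeMLARopeGeneration` call (and its sync) from `AttentionOp::mlaGeneration` and adds a new `dsv3RopeOp.cpp` source to the `th_common` target, so RoPE for MLA generation is issued as its own op and can be scheduled to overlap with the uk-bgemm. The snippet below is only a generic sketch of the overlap idea: two independent workloads issued on separate CUDA streams, with an event so downstream work waits on both. The stream and buffer names are hypothetical and the memsets are stand-ins for the real kernels; this is not the repository's implementation.

```cpp
// Generic sketch of overlapping two independent GPU workloads (for example a
// RoPE kernel and a batched GEMM) on separate CUDA streams. Names are
// hypothetical stand-ins; this is not the TensorRT-LLM implementation.
#include <cuda_runtime.h>

int main()
{
    cudaStream_t stream_rope = nullptr;
    cudaStream_t stream_bgemm = nullptr;
    cudaEvent_t rope_done = nullptr;
    cudaStreamCreate(&stream_rope);
    cudaStreamCreate(&stream_bgemm);
    cudaEventCreateWithFlags(&rope_done, cudaEventDisableTiming);

    void* buf_rope = nullptr;
    void* buf_bgemm = nullptr;
    cudaMalloc(&buf_rope, 1 << 20);
    cudaMalloc(&buf_bgemm, 1 << 20);

    // Stand-ins for the two kernels: independent async work issued on
    // different streams may execute concurrently on the device.
    cudaMemsetAsync(buf_rope, 0, 1 << 20, stream_rope);   // "RoPE"
    cudaMemsetAsync(buf_bgemm, 0, 1 << 20, stream_bgemm); // "uk-bgemm"

    // Downstream work that needs both results (for example the FMHA kernel)
    // waits on the RoPE stream via an event instead of a full device sync.
    cudaEventRecord(rope_done, stream_rope);
    cudaStreamWaitEvent(stream_bgemm, rope_done, 0);

    cudaStreamSynchronize(stream_bgemm);
    cudaFree(buf_rope);
    cudaFree(buf_bgemm);
    cudaEventDestroy(rope_done);
    cudaStreamDestroy(stream_rope);
    cudaStreamDestroy(stream_bgemm);
    return 0;
}
```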

File tree

11 files changed: +802 −76 lines

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 22 additions & 43 deletions
@@ -870,22 +870,19 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
         size_t fmha_scheduler_counter = sizeof(uint32_t);
         size_t headDim = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
 
-        int const NUM_BUFFERS = 10;
+        int const NUM_BUFFERS = 7;
         size_t workspaces[NUM_BUFFERS];
-        workspaces[0] = cu_seqlens_size; // cu_q_len
-        workspaces[1] = cu_seqlens_size; // cu_kv_len
-        workspaces[2] = fmha_scheduler_counter;
-        workspaces[3] = mFP8GenerationMLA ? sizeof(float) * 2 : 0; // mla_bmm1_scale_size
-        workspaces[4] = mFP8GenerationMLA ? sizeof(float) : 0;     // mla_bmm2_scale_size
-        workspaces[5] = mFP8GenerationMLA ? max_num_tokens * size_t(mNumHeads * headDim) : 0; // quant q buffer
+        workspaces[0] = mIsGenerationMLA ? 0 : cu_seqlens_size; // cu_q_len
+        workspaces[1] = mIsGenerationMLA ? 0 : cu_seqlens_size; // cu_kv_len
+        workspaces[2] = mIsGenerationMLA ? 0 : fmha_scheduler_counter;
         // The multiCtasKvMode buffers. Each CTA at most handles 256 rows.
         // And the seqLenKv is split into at most mMultiProcessorCount tiles.
-        workspaces[6] = size * 256 * mMultiProcessorCount * headDim;
+        workspaces[3] = size * 256 * mMultiProcessorCount * headDim;
         // The partialSum size.
-        workspaces[7] = sizeof(float) * 256 * mMultiProcessorCount;
+        workspaces[4] = sizeof(float) * 256 * mMultiProcessorCount;
         // The partialMax size.
-        workspaces[8] = sizeof(float) * 256 * mMultiProcessorCount;
-        workspaces[9] = flash_mla_workspace_size;
+        workspaces[5] = sizeof(float) * 256 * mMultiProcessorCount;
+        workspaces[6] = flash_mla_workspace_size;
 
         fmha_v2_mla_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
     }
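
The trimmed `workspaces[]` array is still folded into one allocation by `tc::calculateTotalWorkspaceSize`. As a rough model of what such a helper does (the actual alignment and signature in the repository may differ), it pads each requested size to an alignment boundary and sums the results, so the zeroed generation-MLA entries simply drop out of the total:

```cpp
// Minimal sketch of an aligned workspace-size aggregator. The 256-byte
// alignment is an assumption for illustration; the real
// tc::calculateTotalWorkspaceSize may differ in detail.
#include <cstddef>

constexpr std::size_t kAlignment = 256; // assumed

inline std::size_t alignUp(std::size_t x, std::size_t a)
{
    return (x + a - 1) / a * a;
}

inline std::size_t calculateTotalWorkspaceSizeSketch(std::size_t const* sizes, int count)
{
    std::size_t total = 0;
    for (int i = 0; i < count; ++i)
    {
        // Zero-sized entries (such as the generation-MLA buffers that are now
        // provided by the caller) contribute nothing to the total.
        total += alignUp(sizes[i], kAlignment);
    }
    return total;
}
```
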
@@ -962,6 +959,16 @@ template <typename T>
 int AttentionOp::mlaGeneration(
     MlaParams<T>& params, EnqueueGenerationParams<T> const& generation_params, cudaStream_t stream)
 {
+    TLLM_CHECK_WITH_INFO(params.seqQOffset != nullptr, "seqQOffset is nullptr.");
+    TLLM_CHECK_WITH_INFO(params.cache_seq_lens != nullptr, "cache_seq_lens is nullptr.");
+    TLLM_CHECK_WITH_INFO(params.fmha_tile_counter != nullptr, "fmha_tile_counter is nullptr.");
+    if (mFP8GenerationMLA)
+    {
+        TLLM_CHECK_WITH_INFO(params.quant_q_buf != nullptr, "quant_q_buf is nullptr.");
+        TLLM_CHECK_WITH_INFO(params.bmm1_scale != nullptr, "bmm1_scale is nullptr.");
+        TLLM_CHECK_WITH_INFO(params.bmm2_scale != nullptr, "bmm2_scale is nullptr.");
+    }
+
     int const num_kv_heads = 1;
     int const head_size = mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
     int32_t const batch_beam = generation_params.beam_width * generation_params.num_requests;
@@ -983,33 +990,8 @@ int AttentionOp::mlaGeneration(
     // Workspace pointer shift
     int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
     size_t offset = 0;
-
-    size_t const cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
-    size_t const fmha_scheduler_counter = sizeof(uint32_t);
-    size_t const mla_bmm1_scale_size = mFP8GenerationMLA ? sizeof(float) * 2 : 0;
-    size_t const mla_bmm2_scale_size = mFP8GenerationMLA ? sizeof(float) : 0;
-    size_t const quant_q_buffer_size = mFP8GenerationMLA
-        ? params.acc_q_len * size_t(mNumHeads * (mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim))
-        : 0;
-    int* cu_q_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
-    int* cu_kv_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
-    uint32_t* fmha_tile_counter_ptr
-        = reinterpret_cast<uint32_t*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter));
-    float* mla_bmm1_scale_ptr
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm1_scale_size));
-    float* mla_bmm2_scale_ptr
-        = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, mla_bmm2_scale_size));
-    void* quant_q_buffer_ptr
-        = reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, quant_q_buffer_size));
     void* scratch_ptr = nextWorkspacePtr(workspace_byte_ptr, offset);
 
-    params.seqQOffset = cu_q_seqlens;
-    params.cu_kv_seqlens = cu_kv_seqlens;
-    params.fmha_tile_counter = fmha_tile_counter_ptr;
-    params.bmm1_scale = mla_bmm1_scale_ptr;
-    params.bmm2_scale = mla_bmm2_scale_ptr;
-    params.quant_q_buf = quant_q_buffer_ptr;
-
     params.quant_scale_o = generation_params.attention_output_orig_quant;
     params.quant_scale_q = generation_params.kv_scale_orig_quant;
     params.quant_scale_kv = generation_params.kv_scale_orig_quant;
@@ -1018,9 +1000,6 @@ int AttentionOp::mlaGeneration(
     params.host_bmm1_scale
         = 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
 
-    invokeMLARopeGeneration<T>(params, kv_cache_buffer, stream);
-    sync_check_cuda_error(stream);
-
     if (generation_params.runtime_perf_knobs)
     {
         int64_t multi_block_mode_val = generation_params.runtime_perf_knobs[0];
@@ -1261,7 +1240,7 @@ int AttentionOp::mlaGeneration(
     XQAParams xqaParams{};
     this->template convertMMHAParamsToXQAParams<T, decltype(kv_cache_buffer)>(
         xqaParams, generation_params, /*forConfigurePlugin=*/false);
-    xqaParams.quant_q_buffer_ptr = quant_q_buffer_ptr;
+    xqaParams.quant_q_buffer_ptr = params.quant_q_buf;
    xqaParams.q_scaling
        = 1 / (mQScaling * sqrtf((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
    if (mEnableXQA && mXqaDispatcher->shouldUse(xqaParams))
@@ -1303,11 +1282,11 @@ int AttentionOp::mlaGeneration(
 
     // fmhaParams.packedMaskPtr = params.fmha_custom_mask;
     fmhaParams.pagedKvCache = kv_cache_buffer;
-    fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
+    fmhaParams.cuQSeqLenPtr = params.seqQOffset;
     fmhaParams.kvSeqLenPtr = params.cache_seq_lens;
-    fmhaParams.cuKvSeqLenPtr = cu_kv_seqlens;
+    fmhaParams.cuKvSeqLenPtr = params.cu_kv_seqlens;
     fmhaParams.cuMaskRowsPtr = nullptr; // mla not support custorm mask right now
-    fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
+    fmhaParams.tileCounterPtr = params.fmha_tile_counter;
     fmhaParams.scaleBmm1Ptr = reinterpret_cast<float const*>(params.bmm1_scale);
     fmhaParams.scaleBmm2Ptr = reinterpret_cast<float const*>(params.bmm2_scale);
     fmhaParams.stream = stream;
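
The deleted lines above carved `cu_q_seqlens`, `cu_kv_seqlens`, the FMHA tile counter, the two BMM scales, and the quantized-Q buffer out of the op's workspace with `nextWorkspacePtr`; after this change `mlaGeneration` only takes `scratch_ptr` from the workspace and expects those pointers to arrive pre-populated in `MlaParams` (wired up in cpp/tensorrt_llm/thop/attentionOp.cpp below and validated by the new `TLLM_CHECK_WITH_INFO` guards). For reference, a bump-pointer carver in the spirit of `nextWorkspacePtr` could look like the sketch below; the alignment constant and the exact signature are assumptions, not the repository's definition.

```cpp
// Illustrative bump-pointer workspace carver in the spirit of
// nextWorkspacePtr(workspace_byte_ptr, offset, size); alignment is assumed.
#include <cstddef>
#include <cstdint>

constexpr std::size_t kWsAlignment = 256; // assumed

inline void* nextWorkspacePtrSketch(std::int8_t* base, std::size_t& offset, std::size_t size = 0)
{
    void* ptr = base + offset;
    // Advance the running offset by the aligned size so the next buffer
    // starts on an aligned boundary; size == 0 just returns the current tail.
    offset += (size + kWsAlignment - 1) / kWsAlignment * kWsAlignment;
    return ptr;
}
```

Called with a size, such a helper hands out the next aligned slice and advances the running offset; called without one, it returns the current tail, which is consistent with how `scratch_ptr` is obtained above.
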

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ void initBindings(nb::module_& m)
         nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
         nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
         nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
+        nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
+        nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
+        nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
         "Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::thop

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ void initBindings(pybind11::module_& m)
         py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
         py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
         py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
+        py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
+        py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
+        py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
         "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::thop
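
Both binding layers expose the same six new keyword arguments with `std::nullopt` defaults, so existing Python call sites that do not pass them keep working and the C++ side receives empty optionals. A stripped-down sketch of this optional-argument pattern is shown below; the module and function are hypothetical and plain integers stand in for `torch::Tensor` so the example stays self-contained. pybind11 uses the identical form with `py::arg`.

```cpp
// Minimal nanobind sketch of binding a function whose optional arguments
// default to None on the Python side. Module and function names are
// hypothetical; they are not part of the TensorRT-LLM bindings.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <cstdint>
#include <optional>

namespace nb = nanobind;

static std::int64_t count_present(
    std::optional<std::int64_t> cu_q_seqlens, std::optional<std::int64_t> quant_q_buffer)
{
    // Each argument is present only if the Python caller supplied it.
    return static_cast<std::int64_t>(cu_q_seqlens.has_value())
        + static_cast<std::int64_t>(quant_q_buffer.has_value());
}

NB_MODULE(optional_args_demo, m)
{
    // nb::arg("...") = std::nullopt makes the parameter an optional keyword
    // argument defaulting to None, mirroring the diff above.
    m.def("count_present", &count_present, nb::arg("cu_q_seqlens") = std::nullopt,
        nb::arg("quant_q_buffer") = std::nullopt, "Counts how many optional arguments were provided");
}
```

Under the assumed module name, `optional_args_demo.count_present()` returns 0 and `optional_args_demo.count_present(cu_q_seqlens=1)` returns 1.
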

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -99,7 +99,8 @@ add_library(
   mtpOp.cpp
   loraOp.cpp
   finegrained_mixed_dtype_gemm_thop.cpp
-  tinygemm2.cpp)
+  tinygemm2.cpp
+  dsv3RopeOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(
   th_common PRIVATE ${TORCH_LIBRARIES} th_utils ${Python3_LIBRARIES}

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 39 additions & 10 deletions
@@ -86,7 +86,10 @@ class RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
+        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };
 
@@ -143,7 +146,10 @@ class Runner : public RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
+        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -216,6 +222,13 @@ class Runner : public RunnerBase
                 v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
                 mla_params.k_buf = k_ptr;
                 mla_params.v_buf = v_ptr;
+
+                // For generation, helix position is in ropeOp
+                auto& mla_helix_position_offsets = mla_tensor_params[0];
+                if (mla_helix_position_offsets.has_value())
+                {
+                    mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
+                }
             }
             else
             {
@@ -228,6 +241,22 @@ class Runner : public RunnerBase
                 mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
                 mla_params.q_pe_ld = q_pe->strides()[1];
                 mla_params.q_pe_stride = q_pe->strides()[0];
+
+                mla_params.seqQOffset
+                    = cu_q_seqlens.has_value() ? reinterpret_cast<int*>(cu_q_seqlens.value().data_ptr()) : nullptr;
+                mla_params.cu_kv_seqlens
+                    = cu_kv_seqlens.has_value() ? reinterpret_cast<int*>(cu_kv_seqlens.value().data_ptr()) : nullptr;
+                mla_params.fmha_tile_counter = fmha_scheduler_counter.has_value()
+                    ? reinterpret_cast<uint32_t*>(fmha_scheduler_counter.value().data_ptr())
+                    : nullptr;
+                mla_params.bmm1_scale = mla_bmm1_scale.has_value()
+                    ? reinterpret_cast<float*>(mla_bmm1_scale.value().data_ptr())
+                    : nullptr;
+                mla_params.bmm2_scale = mla_bmm2_scale.has_value()
+                    ? reinterpret_cast<float*>(mla_bmm2_scale.value().data_ptr())
+                    : nullptr;
+                mla_params.quant_q_buf
+                    = quant_q_buffer.has_value() ? reinterpret_cast<void*>(quant_q_buffer.value().data_ptr()) : nullptr;
             }
             mla_params.q_buf = attention_input;
             mla_params.context_buf = reinterpret_cast<T*>(context_buf);
@@ -239,11 +268,6 @@ class Runner : public RunnerBase
             mla_params.meta = op.mMLAParams;
 
             mla_params.workspace = workspace_ptr;
-            auto& mla_helix_position_offsets = mla_tensor_params[0];
-            if (mla_helix_position_offsets.has_value())
-            {
-                mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
-            }
         }
 
         int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
@@ -565,7 +589,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk)
+    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -829,7 +856,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value);
+            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
+            quant_q_buffer);
     }
 
     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -847,7 +875,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value);
+            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
+            quant_q_buffer);
     }
 
     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
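
Each of the six optional tensors is unwrapped with the same `has_value() ? reinterpret_cast<...>(...data_ptr()) : nullptr` expression before being stored in `mla_params`. A small helper in that spirit (hypothetical, not part of this commit) captures the pattern once:

```cpp
// Hypothetical helper expressing the repeated optional-tensor-to-pointer
// pattern from the hunks above; it is not part of the commit.
#include <optional>
#include <torch/torch.h>

template <typename T>
T* dataPtrOrNull(std::optional<torch::Tensor> const& t)
{
    // data_ptr() returns void*; reinterpret it to the caller's element type,
    // and fall back to nullptr when the optional tensor was not passed.
    return t.has_value() ? reinterpret_cast<T*>(t->data_ptr()) : nullptr;
}

// Usage mirroring the diff:
//   mla_params.seqQOffset        = dataPtrOrNull<int>(cu_q_seqlens);
//   mla_params.fmha_tile_counter = dataPtrOrNull<uint32_t>(fmha_scheduler_counter);
//   mla_params.quant_q_buf       = dataPtrOrNull<void>(quant_q_buffer);
```
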

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 4 additions & 1 deletion
@@ -63,6 +63,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk);
+    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer);
 
 } // namespace torch_ext
