@@ -86,7 +86,10 @@ class RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
+        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };

@@ -143,7 +146,10 @@ class Runner : public RunnerBase
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
+        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -216,6 +222,13 @@ class Runner : public RunnerBase
                 v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
                 mla_params.k_buf = k_ptr;
                 mla_params.v_buf = v_ptr;
+
+                // For generation, helix position is in ropeOp
+                auto& mla_helix_position_offsets = mla_tensor_params[0];
+                if (mla_helix_position_offsets.has_value())
+                {
+                    mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
+                }
             }
             else
             {
@@ -228,6 +241,22 @@ class Runner : public RunnerBase
                 mla_params.q_pe = static_cast<T*>(q_pe->data_ptr());
                 mla_params.q_pe_ld = q_pe->strides()[1];
                 mla_params.q_pe_stride = q_pe->strides()[0];
+
+                mla_params.seqQOffset
+                    = cu_q_seqlens.has_value() ? reinterpret_cast<int*>(cu_q_seqlens.value().data_ptr()) : nullptr;
+                mla_params.cu_kv_seqlens
+                    = cu_kv_seqlens.has_value() ? reinterpret_cast<int*>(cu_kv_seqlens.value().data_ptr()) : nullptr;
+                mla_params.fmha_tile_counter = fmha_scheduler_counter.has_value()
+                    ? reinterpret_cast<uint32_t*>(fmha_scheduler_counter.value().data_ptr())
+                    : nullptr;
+                mla_params.bmm1_scale = mla_bmm1_scale.has_value()
+                    ? reinterpret_cast<float*>(mla_bmm1_scale.value().data_ptr())
+                    : nullptr;
+                mla_params.bmm2_scale = mla_bmm2_scale.has_value()
+                    ? reinterpret_cast<float*>(mla_bmm2_scale.value().data_ptr())
+                    : nullptr;
+                mla_params.quant_q_buf
+                    = quant_q_buffer.has_value() ? reinterpret_cast<void*>(quant_q_buffer.value().data_ptr()) : nullptr;
             }
             mla_params.q_buf = attention_input;
             mla_params.context_buf = reinterpret_cast<T*>(context_buf);
@@ -239,11 +268,6 @@ class Runner : public RunnerBase
             mla_params.meta = op.mMLAParams;

             mla_params.workspace = workspace_ptr;
-            auto& mla_helix_position_offsets = mla_tensor_params[0];
-            if (mla_helix_position_offsets.has_value())
-            {
-                mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
-            }
         }

         int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
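The context branch above converts each new std::optional<torch::Tensor> into a raw device pointer with the same `has_value() ? reinterpret_cast<...>(...value().data_ptr()) : nullptr` expression. A minimal sketch of that pattern, assuming only libtorch; the optionalDataPtr helper below is illustrative and not part of this change:

#include <cstdint>
#include <optional>
#include <torch/torch.h>

// Illustrative helper (not in the PR): an absent optional tensor maps to
// nullptr, a present one contributes its raw data pointer, mirroring the
// per-field assignments in the hunk above.
template <typename T>
T* optionalDataPtr(std::optional<torch::Tensor> const& t)
{
    return t.has_value() ? reinterpret_cast<T*>(t.value().data_ptr()) : nullptr;
}

// Hypothetical usage, mirroring the assignments shown in the diff:
//   mla_params.seqQOffset        = optionalDataPtr<int>(cu_q_seqlens);
//   mla_params.cu_kv_seqlens     = optionalDataPtr<int>(cu_kv_seqlens);
//   mla_params.fmha_tile_counter = optionalDataPtr<uint32_t>(fmha_scheduler_counter);
//   mla_params.bmm1_scale        = optionalDataPtr<float>(mla_bmm1_scale);
//   mla_params.bmm2_scale        = optionalDataPtr<float>(mla_bmm2_scale);
//   mla_params.quant_q_buf       = optionalDataPtr<void>(quant_q_buffer);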
@@ -565,7 +589,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk)
+    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+    std::optional<torch::Tensor> quant_q_buffer)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -829,7 +856,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value);
+            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
+            quant_q_buffer);
     }

     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -847,7 +875,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value);
+            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
+            quant_q_buffer);
     }

     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
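For context, a hypothetical caller-side sketch of allocating the new optional auxiliary tensors before invoking the op. The shapes, dtypes, and batch size here are assumptions for illustration and are not taken from this change; the only grounded points are that the runner reads the sequence-length buffers as int*, the scheduler counter as uint32_t*, the scales as float*, and stores nullptr for any argument left as std::nullopt.

#include <optional>
#include <torch/torch.h>

int main()
{
    // Assumed batch size, purely for illustration.
    int64_t const num_contexts = 4;

    auto const i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto const f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);

    // Cumulative sequence-length buffers conventionally carry a leading zero,
    // hence num_contexts + 1 entries (assumption; real sizing is up to the caller).
    std::optional<torch::Tensor> cu_q_seqlens = torch::zeros({num_contexts + 1}, i32);
    std::optional<torch::Tensor> cu_kv_seqlens = torch::zeros({num_contexts + 1}, i32);

    // The runner reinterprets this as uint32_t*, so an int32 tensor is a
    // compatible backing store.
    std::optional<torch::Tensor> fmha_scheduler_counter = torch::zeros({1}, i32);

    // Per-op scale factors, read as float* by the runner.
    std::optional<torch::Tensor> mla_bmm1_scale = torch::ones({1}, f32);
    std::optional<torch::Tensor> mla_bmm2_scale = torch::ones({1}, f32);

    // Leaving an argument empty means the runner stores nullptr in the
    // corresponding mla_params field.
    std::optional<torch::Tensor> quant_q_buffer = std::nullopt;

    (void) cu_q_seqlens; (void) cu_kv_seqlens; (void) fmha_scheduler_counter;
    (void) mla_bmm1_scale; (void) mla_bmm2_scale; (void) quant_q_buffer;
    return 0;
}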