@@ -119,6 +119,162 @@ Array<int64_t> BatchPrefillWithKVCachePlan(
   return Array(plan_info.ToVector());
 }
 
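+// Runs batched prefill attention over a ragged (variable-length, non-paged)
+// KV cache, consuming the scheduling metadata produced by
+// BatchPrefillWithKVCachePlan via plan_info_vec and the two workspace buffers.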
+void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer,
+                                      TensorView int_workspace_buffer,
+                                      Array<int64_t> plan_info_vec,
+                                      TensorView q,
+                                      TensorView k,
+                                      TensorView v,
+                                      TensorView qo_indptr,
+                                      TensorView kv_indptr,
+                                      TensorView o,
+                                      Optional<TensorView> maybe_lse,
+                                      int64_t mask_mode_code,
+                                      int64_t layout,
+                                      int64_t window_left,
+                                      bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+  PrefillPlanInfo plan_info;
+  plan_info.FromVector(
+      std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
+  QKVLayout kv_layout = static_cast<QKVLayout>(layout);
+
+  int64_t num_qo_heads = q->shape[1];
+  int64_t head_dim_qk = q->shape[2];
+  int64_t num_kv_heads =
+      (kv_layout == QKVLayout::kNHD) ? k->shape[1] : k->shape[0];
+  uint32_t q_stride_n = q->strides[0], q_stride_h = q->strides[1], k_stride_n,
+           k_stride_h, v_stride_n, v_stride_h;
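+  // Pick K/V strides according to the KV layout: kNHD stores [seq, heads, dim],
+  // kHND stores [heads, seq, dim].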
+  if (kv_layout == QKVLayout::kNHD) {
+    k_stride_n = k->strides[0];
+    k_stride_h = k->strides[1];
+    v_stride_n = v->strides[0];
+    v_stride_h = v->strides[1];
+  } else {
+    k_stride_h = k->strides[0];
+    k_stride_n = k->strides[1];
+    v_stride_h = v->strides[0];
+    v_stride_n = v->strides[1];
+  }
+
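+  // If an LSE output is requested, it must be shaped [num_q_tokens, num_qo_heads].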
+  if (maybe_lse.has_value()) {
+    const auto& lse = *maybe_lse;
+    TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]);
+    TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]);
+  }
+
+  void* float_buffer_ptr = float_workspace_buffer->data;
+  void* int_buffer_ptr = int_workspace_buffer->data;
+
+  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+
+  cudaSetDevice(float_workspace_buffer->device.device_id);
+  const cudaStream_t stream = get_stream(float_workspace_buffer->device);
+
+  DISPATCH_context(
+      DTypeQ,
+      DTypeKV,
+      DTypeO,
+      IdType,
+      MASK_MODE,
+      HEAD_DIM_QK,
+      HEAD_DIM_VO,
+      POS_ENCODING_MODE,
+      USE_SLIDING_WINDOW,
+      USE_LOGITS_SOFT_CAP,
+      USE_FP16_QK_REDUCTION,
+      AttentionVariant,
+      RaggedParams,
+      PagedParams,
+      [&] {
+        RaggedParams params;
+
+        params.q = static_cast<DTypeQ*>(q->data);
+        params.k = static_cast<DTypeKV*>(k->data);
+        params.v = static_cast<DTypeKV*>(v->data);
+        params.o = static_cast<DTypeO*>(o->data);
+        params.lse = maybe_lse.has_value()
+                         ? static_cast<float*>(maybe_lse.value()->data)
+                         : nullptr;
+        params.q_indptr = static_cast<IdType*>(qo_indptr->data);
+        params.kv_indptr = static_cast<IdType*>(kv_indptr->data);
+        params.num_qo_heads = num_qo_heads;
+        params.num_kv_heads = num_kv_heads;
+        params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
+        params.q_stride_n = q_stride_n;
+        params.q_stride_h = q_stride_h;
+        params.k_stride_n = k_stride_n;
+        params.k_stride_h = k_stride_h;
+        params.v_stride_n = v_stride_n;
+        params.v_stride_h = v_stride_h;
+        params.window_left = window_left;
+
+        params.request_indices = nullptr;
+        params.qo_tile_indices = nullptr;
+        params.kv_tile_indices = nullptr;
+        params.merge_indptr = nullptr;
+        params.o_indptr = nullptr;
+        params.kv_chunk_size_ptr = nullptr;
+        params.block_valid_mask = nullptr;
+        params.total_num_rows = nullptr;
+        params.max_total_num_rows = 0;
+        params.padded_batch_size = 0;
+        params.partition_kv = false;
+
+        ADDITIONAL_PARAMS_SETTER
+
+        DTypeO* tmp_v = nullptr;
+        float* tmp_s = nullptr;
+
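+        // Wire up the scheduling metadata that the planner wrote into the
+        // int workspace buffer.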
+        params.request_indices = GetPtrFromBaseOffset<IdType>(
+            int_buffer_ptr, plan_info.request_indices_offset);
+        params.qo_tile_indices = GetPtrFromBaseOffset<IdType>(
+            int_buffer_ptr, plan_info.qo_tile_indices_offset);
+        params.kv_tile_indices = GetPtrFromBaseOffset<IdType>(
+            int_buffer_ptr, plan_info.kv_tile_indices_offset);
+        params.o_indptr = GetPtrFromBaseOffset<IdType>(
+            int_buffer_ptr, plan_info.o_indptr_offset);
+        params.kv_chunk_size_ptr = GetPtrFromBaseOffset<IdType>(
+            int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset);
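+        // When the planner split the KV dimension, partial outputs and LSEs
+        // are staged in tmp_v/tmp_s and later merged via merge_indptr.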
+        if (plan_info.split_kv) {
+          params.merge_indptr = GetPtrFromBaseOffset<IdType>(
+              int_buffer_ptr, plan_info.merge_indptr_offset);
+          tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr,
+                                               plan_info.v_offset);
+          tmp_s =
+              GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.s_offset);
+          if (plan_info.enable_cuda_graph) {
+            params.block_valid_mask = GetPtrFromBaseOffset<bool>(
+                int_buffer_ptr, plan_info.block_valid_mask_offset);
+          }
+        }
+        params.padded_batch_size = plan_info.padded_batch_size;
+        params.max_total_num_rows = plan_info.total_num_rows;
+        if (plan_info.enable_cuda_graph) {
+          params.total_num_rows = GetPtrFromBaseOffset<uint32_t>(
+              int_buffer_ptr, plan_info.total_num_rows_offset);
+        }
+
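+        // Dispatch on the query CTA tile size selected by the planner.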
+        cudaError_t status = cudaSuccess;
+
+        DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, {
+          status = flashinfer::BatchPrefillWithRaggedKVCacheDispatched<
+              CTA_TILE_Q,
+              HEAD_DIM_QK,
+              HEAD_DIM_VO,
+              POS_ENCODING_MODE,
+              /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION,
+              MASK_MODE,
+              AttentionVariant,
+              RaggedParams>(params, tmp_v, tmp_s, enable_pdl, stream);
+        });
+
+        TVM_FFI_ICHECK(status == cudaSuccess)
+            << "BatchPrefillWithRaggedKVCache failed with error "
+            << cudaGetErrorString(status);
+        return true;
+      });
+}
+
 void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer,
                                      TensorView int_workspace_buffer,
                                      Array<int64_t> plan_info_vec,
@@ -287,4 +443,58 @@ void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer,
   });
 }
 
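+// Torch-facing entry point: plans and then runs causal prefill attention over
+// ragged (cu_seqlens-indexed) Q/K/V tensors.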
+void batch_prefill(torch::Tensor& float_workspace_buffer,
+                   torch::Tensor& int_workspace_buffer,
+                   torch::Tensor& page_locked_int_workspace_buffer,
+                   const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   const torch::Tensor& q_cu_seq_lens,
+                   const torch::Tensor& kv_cu_seq_lens,
+                   int64_t window_size_left,
+                   torch::Tensor& output,
+                   std::optional<torch::Tensor>& output_lse,
+                   bool enable_cuda_graph,
+                   bool enable_pdl) {
+  torch::Tensor kv_len_arr =
+      kv_cu_seq_lens.slice(0, 1) - kv_cu_seq_lens.slice(0, 0, -1);
+  const int64_t total_num_rows = q_cu_seq_lens[-1].item<int64_t>();
+  const int64_t batch_size = q_cu_seq_lens.size(0) - 1;
+
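+  // Build the work-partitioning plan for this batch (ragged KV, page_size 1).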
+  Array<int64_t> plan_info_vec = BatchPrefillWithKVCachePlan(
+      float_workspace_buffer,
+      int_workspace_buffer,
+      page_locked_int_workspace_buffer,
+      q_cu_seq_lens,
+      kv_cu_seq_lens,
+      kv_len_arr,
+      total_num_rows,
+      batch_size,
+      query.size(1),   // num_qo_heads
+      key.size(1),     // num_kv_heads
+      /*page_size=*/1,
+      enable_cuda_graph,
+      query.size(-1),  // head_dim_qk
+      value.size(-1),  // head_dim_vo
+      /*causal=*/true,
+      window_size_left,
+      /*fixed_split_size=*/-1,
+      /*disable_split_kv=*/false);
+
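+  // Execute the ragged-KV prefill kernel with a causal mask and NHD layout.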
+  BatchPrefillWithRaggedKVCacheRun(float_workspace_buffer,
+                                   int_workspace_buffer,
+                                   plan_info_vec,
+                                   query,
+                                   key,
+                                   value,
+                                   q_cu_seq_lens,
+                                   kv_cu_seq_lens,
+                                   output,
+                                   output_lse,
+                                   /*mask_mode_code=*/1,  // MaskMode::kCausal
+                                   /*layout=*/0,          // QKVLayout::kNHD
+                                   window_size_left,
+                                   enable_pdl);
+}
+
 }  // namespace xllm::kernel::cuda