I built FlashInfer with AOT and call the batch prefill functions (plan and ragged_run) through the resulting .so file. Here is the code:
namespace cuda{
void batch_prefill(torch::Tensor& float_workspace_buffer,
torch::Tensor& int_workspace_buffer,
torch::Tensor& page_locked_int_workspace_buffer,
const torch::Tensor& query,
const torch::Tensor& key,
const torch::Tensor& value,
const torch::Tensor& q_cu_seq_lens,
const torch::Tensor& kv_cu_seq_lens,
int64_t window_size_left,
torch::Tensor& output,
std::optional<torch::Tensor>& output_lse,
bool enable_cuda_graph) {
std::string uri = get_batch_prefill_uri(/*backend=*/"fa2",
query.scalar_type(),
key.scalar_type(),
output.scalar_type(),
q_cu_seq_lens.scalar_type(),
query.size(-1),
value.size(-1),
/*pos_encoding_mode=*/0,
/*use_sliding_window=*/false,
/*use_logits_soft_cap=*/false,
/*use_fp16_qk_reduction=*/false);
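// plan() consumes host-side copies of the indptr arrays, so move them to CPU first.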
torch::Tensor kv_indptr_host = kv_cu_seq_lens.to(torch::kCPU);
torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU);
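// Per-request kv lengths: kv_indptr[1:] - kv_indptr[:-1].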
torch::Tensor kv_len_arr_host =
kv_indptr_host.slice(0, 1) - kv_indptr_host.slice(0, 0, -1);
const int64_t total_num_rows = qo_indptr_host.size(0);
const int64_t batch_size = q_cu_seq_lens.size(0) - 1;
const double sm_scale = compute_sm_scale(query.size(-1));
// the type of plan_info is ffi::Any
auto plan_info = get_module(uri)->GetFunction("plan").value()(
to_ffi_tensor(float_workspace_buffer),
to_ffi_tensor(int_workspace_buffer),
to_ffi_tensor(page_locked_int_workspace_buffer),
to_ffi_tensor(qo_indptr_host),
to_ffi_tensor(kv_indptr_host),
to_ffi_tensor(kv_len_arr_host),
total_num_rows,
batch_size,
query.size(1), // num_qo_heads
key.size(1), // num_kv_heads
/*page_size=*/1,
enable_cuda_graph,
query.size(-1), // head_dim_qk
value.size(-1), // head_dim_vo
/*causal=*/true, // causal
/*window_size_left=*/1,
/*fixed_split_size=*/-1,
/*disable_split_kv=*/false);
// convert plan_info to ffi::Array
auto prefill_plan_info = plan_info.cast<ffi::Array<int64_t>>();
get_module(uri)
->GetFunction("ragged_run")
.value()(to_ffi_tensor(float_workspace_buffer),
to_ffi_tensor(int_workspace_buffer),
prefill_plan_info,
to_ffi_tensor(query),
to_ffi_tensor(key),
to_ffi_tensor(value),
to_ffi_tensor(q_cu_seq_lens),
to_ffi_tensor(kv_cu_seq_lens),
to_ffi_tensor(output),
ffi::Optional<ffi::Tensor>(),
/*mask_mode_code=CAUSAL*/ 1,
/*kv_layout_code=*/0, // NHD layout
/*window_size_left=*/-1,
support_pdl(),
/*maybe_custom_mask=*/ffi::Optional<ffi::Tensor>(),
/*maybe_mask_indptr=*/ffi::Optional<ffi::Tensor>(),
/*maybe_alibi_slopes=*/ffi::Optional<ffi::Tensor>(),
/*maybe_prefix_len_ptr=*/ffi::Optional<ffi::Tensor>(),
/*maybe_token_pos_in_items_ptr=*/ffi::Optional<ffi::Tensor>(),
/*maybe_max_item_len_ptr=*/ffi::Optional<ffi::Tensor>(),
/*logits_soft_cap=*/0.0,
/*sm_scale=*/sm_scale,
/*rope_rcp_scale=*/1.0,
/*rope_rcp_theta=*/1.0 / 10000.0,
/*token_pos_in_items_len=*/0);
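// Note: ragged_run only enqueues work on the current CUDA stream; reaching this
// point does not guarantee the kernels have completed (or succeeded).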
std::cout << "batch_prefill cuda run end" << std::endl;
}
} // namespace cuda

Here is how cuda::batch_prefill is called:
void batch_prefill(AttentionParams& params) {
std::cout << "batch_prefill cuda in ops_api" << std::endl;
cuda::batch_prefill(params.float_workspace_buffer,
params.int_workspace_buffer,
params.page_locked_int_workspace_buffer,
params.query,
params.key,
params.value,
params.q_cu_seq_lens,
params.kv_cu_seq_lens,
params.window_size_left,
params.output,
params.output_lse,
params.enable_cuda_graph);
std::cout << "batch_prefill cuda end in ops_api" << std::endl;
}

Here is the log when I run the above function:
batch_prefill cuda in ops_api
batch_prefill cuda run end
!!!!!!! Segfault encountered !!!!!!!
File "src/ffi/backtrace.cc", line 154, in TVMFFISegFaultHandler(int)
File "./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c", line 0, in 0x00007f3f32b7f51f
According to the log, cuda::batch_prefill runs to its last line, since "batch_prefill cuda run end" is printed; but "batch_prefill cuda end in ops_api" is never printed, and the process crashes with a segfault in between. I have no idea why the segfault happens even though cuda::batch_prefill itself finishes.
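To narrow this down, one check I can add is a forced device synchronization right after ragged_run, since a crash that appears after the last line of a function often comes either from an asynchronous CUDA error surfacing late or from something destroyed during function teardown (for example a temporary created by to_ffi_tensor). check_cuda_sync below is a small hypothetical helper for debugging, not part of FlashInfer or tvm-ffi:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Hypothetical debugging helper: force all queued CUDA work to finish so that
// any asynchronous kernel error surfaces here rather than at some later point,
// such as the return from cuda::batch_prefill.
inline void check_cuda_sync(const char* where) {
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "%s: CUDA error: %s\n", where, cudaGetErrorString(err));
    std::abort();
  }
}

Calling check_cuda_sync("after ragged_run") just before the final std::cout in cuda::batch_prefill should separate the two cases: an abort with a CUDA error string would point at the kernel or its arguments, while a segfault that still happens afterwards would point at host-side teardown.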