
Commit 86d3e13

yzh119, cyx-6, and Kathryn-cat authored
refactor: using tvm-ffi for multi-platform bindings (#1641)
## 📌 Description

This PR refactors the codebase to use tvm-ffi for the Python bindings. The goals are as follows:

1. Support multiple backends.
2. Make AOT wheels PyTorch-version agnostic.
3. Accelerate compilation and reduce binary bloat by minimizing dependencies.

Co-authored-by: Yaxing Cai <[email protected]>
Co-authored-by: Kathryn Chen <[email protected]>

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @tqchen

---------

Co-authored-by: Yaxing Cai <[email protected]>
Co-authored-by: Kathryn-cat <[email protected]>
Co-authored-by: Kathryn-cat <[email protected]>
Co-authored-by: Yaxing Cai <[email protected]>
1 parent ea56964 commit 86d3e13
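
The mechanical pattern throughout the diff is the same: PyTorch-specific pieces (`at::Tensor`, `TORCH_CHECK`, `c10::cuda` device guards and streams) give way to tvm-ffi's DLPack-style tensors, whose fields are read directly (`->data`, `->shape`, `->strides`, `->device`). As a rough illustration only, here is a minimal hypothetical op written in the new style; `VectorAdd` is not part of this commit, but `Tensor`, `TVM_FFI_ICHECK`, `get_stream`, and `tvm_ffi_utils.h` are the helpers the refactored files use.

```cpp
// Sketch only: a hypothetical element-wise op bound in the tvm-ffi style used
// by this commit. VectorAdd does not exist in FlashInfer; Tensor,
// TVM_FFI_ICHECK, get_stream, and tvm_ffi_utils.h do.
#include <cuda_runtime.h>

#include "tvm_ffi_utils.h"

__global__ void VectorAddKernel(const float* a, const float* b, float* out, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) out[i] = a[i] + b[i];
}

void VectorAdd(Tensor a, Tensor b, Tensor out) {
  // DLPack-style field access replaces at::Tensor methods:
  // ->shape[i] instead of .size(i), ->data instead of .data_ptr().
  int64_t n = a->shape[0];
  TVM_FFI_ICHECK(b->shape[0] == n && out->shape[0] == n) << "shape mismatch";

  // Device/stream handling replaces c10::cuda guards and getCurrentCUDAStream.
  cudaSetDevice(a->device.device_id);
  cudaStream_t stream = get_stream(a->device);

  unsigned int num_blocks = static_cast<unsigned int>((n + 255) / 256);
  VectorAddKernel<<<num_blocks, 256, 0, stream>>>(
      static_cast<const float*>(a->data), static_cast<const float*>(b->data),
      static_cast<float*>(out->data), n);

  // TVM_FFI_ICHECK replaces TORCH_CHECK and streams the error message.
  cudaError_t status = cudaGetLastError();
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "Failed to launch VectorAdd, error: " << cudaGetErrorString(status);
}
```

Dropping the libtorch types from the binding surface is what the "PyTorch version agnostic AOT wheels" goal refers to: the compiled extension no longer links against a specific torch ABI.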

File tree

155 files changed: +5936, -6712 lines


ci/bash.sh

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ echo "ENV VARIABLES: ${DOCKER_ENV}"
 echo "VOLUMES: ${DOCKER_VOLUMNS}"
 echo "COMMANDS: '${COMMAND[@]}'"
 
+# Pull the latest docker image
+echo "Pulling latest image: ${DOCKER_IMAGE_NAME}"
+${DOCKER_BINARY} pull ${DOCKER_IMAGE_NAME}
+
 # By default we cleanup - remove the container once it finish running (--rm)
 # and share the PID namespace (--pid=host) so the process inside does not have
 # pid 1 and SIGKILL is propagated to the process inside (jenkins can kill it).

csrc/activation.cu

Lines changed: 0 additions & 128 deletions
This file was deleted.

csrc/aot_extension_utils.h

Lines changed: 0 additions & 50 deletions
This file was deleted.

csrc/batch_attention.cu

Lines changed: 51 additions & 57 deletions
@@ -17,14 +17,15 @@
 #include <flashinfer/attention/scheduler.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/pos_enc.cuh>
-#include <optional>
 
 #include "batch_attention_config.inc"
-#include "pytorch_conversion_utils.h"
-#include "pytorch_extension_utils.h"
+#include "tvm_ffi_utils.h"
 
 namespace flashinfer {
 
+using tvm::ffi::Array;
+using tvm::ffi::Optional;
+
 template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
           MaskMode MASK_MODE, typename AttentionVariant, typename Params>
 cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
@@ -34,80 +35,73 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa
 
 using namespace flashinfer;
 
-at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer,
-                                   at::Tensor int_workspace_buffer,
-                                   at::Tensor page_locked_int_workspace_buffer,
-                                   at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
-                                   int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
-                                   int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
+                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
+                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
+                                       int64_t num_qo_heads, int64_t num_kv_heads,
+                                       int64_t head_dim_o, bool causal) {
   size_t float_workspace_size_in_bytes =
-      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
-      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
+      int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
 
   HolisticPlanInfo<2> plan_info;
 
-  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
-  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  cudaSetDevice(float_workspace_buffer->device.device_id);
+  const cudaStream_t stream = get_stream(float_workspace_buffer->device);
 
   cudaError_t status = TwoStageHolisticPlan<IdType>(
-      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
-      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
-      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
-      kv_indptr.data_ptr<IdType>(), kv_len.data_ptr<IdType>(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim_o, causal, stream);
+      float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data,
+      page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info,
+      static_cast<IdType*>(qo_indptr->data), static_cast<IdType*>(kv_indptr->data),
+      static_cast<IdType*>(kv_len->data), batch_size, num_qo_heads, num_kv_heads, head_dim_o,
+      causal, stream);
 
-  TORCH_CHECK(status == cudaSuccess,
-              "Failed to plan persistent paged attention, error: ", cudaGetErrorString(status));
+  TVM_FFI_ICHECK(status == cudaSuccess)
+      << "Failed to plan persistent paged attention, error: " << cudaGetErrorString(status);
 
-  return vec_to_tensor(plan_info.ToVector());
+  return Array(plan_info.ToVector());
 }
 
-void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-                            at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache,
-                            at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
-                            std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
-                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size,
+void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
+                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
+                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
+                            int64_t num_kv_heads, int64_t page_size,
                             double v_scale,  // must use double due to pytorch binding
                             double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
-  plan_info.FromVector(tensor_to_vec(plan_info_vec));
-
-  auto device = q.device();
+  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
 
-  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
-  void* int_buffer_ptr = int_workspace_buffer.data_ptr();
+  void* float_buffer_ptr = float_workspace_buffer->data;
+  void* int_buffer_ptr = int_workspace_buffer->data;
 
   const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
 
-  auto q_scalar_type = q.scalar_type();
-  auto kv_scalar_type = k_cache.scalar_type();
-
   // NOTE (Yilong): assume both q and o are NHD
-  unsigned int q_stride_n = q.stride(0);
-  unsigned int q_stride_h = q.stride(1);
+  unsigned int q_stride_n = q->strides[0];
+  unsigned int q_stride_h = q->strides[1];
 
   // layout only constraint paged KV
   const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code);
-  unsigned int k_stride_page = k_cache.stride(0);
-  unsigned int v_stride_page = v_cache.stride(0);
+  unsigned int k_stride_page = k_cache->strides[0];
+  unsigned int v_stride_page = v_cache->strides[0];
   unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h;
   if (kv_layout == QKVLayout::kNHD) {
-    k_stride_h = k_cache.stride(2);
-    k_stride_n = k_cache.stride(1);
-    v_stride_h = v_cache.stride(2);
-    v_stride_n = v_cache.stride(1);
+    k_stride_h = k_cache->strides[2];
+    k_stride_n = k_cache->strides[1];
+    v_stride_h = v_cache->strides[2];
+    v_stride_n = v_cache->strides[1];
   } else {
-    k_stride_h = k_cache.stride(1);
-    k_stride_n = k_cache.stride(2);
-    v_stride_h = v_cache.stride(1);
-    v_stride_n = v_cache.stride(2);
+    k_stride_h = k_cache->strides[1];
+    k_stride_n = k_cache->strides[2];
+    v_stride_h = v_cache->strides[1];
+    v_stride_n = v_cache->strides[2];
  }
 
-  const c10::cuda::OptionalCUDAGuard device_guard(device);
-  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  cudaSetDevice(q->device.device_id);
+  const cudaStream_t stream = get_stream(q->device);
 
   DISPATCH_context(
       DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -116,17 +110,17 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
       IdType* len_kv_chunk =
          GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.len_kv_chunk_offset);
      for (int i = 0; i < 2; i++) {
-        params[i].q = static_cast<DTypeQ*>(q.data_ptr());
-        params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
-        params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr());
+        params[i].q = static_cast<DTypeQ*>(q->data);
+        params[i].k = static_cast<DTypeKV*>(k_cache->data);
+        params[i].v = static_cast<DTypeKV*>(v_cache->data);
 
        params[i].q_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset);
        params[i].kv_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset);
        params[i].partial_indptr = GetPtrFromBaseOffset<IdType>(
            int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset);
-        params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr());
+        params[i].kv_indices = static_cast<int*>(kv_indices->data);
        params[i].q_len =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset);
        params[i].kv_len =
@@ -143,9 +137,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
        params[i].len_kv_chunk = len_kv_chunk + i;
 
-        params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
+        params[i].final_o = static_cast<DTypeO*>(o->data);
        params[i].final_lse =
-            maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
+            maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr;
        params[i].partial_o =
            GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
        params[i].partial_lse =
@@ -184,8 +178,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
        cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO,
                                                           MASK_MODE, AttentionVariant>(
            params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream);
-        TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ",
-                    cudaGetErrorString(status));
+        TVM_FFI_ICHECK(status == cudaSuccess)
+            << "Failed to run persistent paged attention, error: " << cudaGetErrorString(status);
        return true;
      });
 }
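
A recurring detail in the diff above is how the plan metadata crosses the binding boundary. Previously the plan step packed its integers into an `at::Tensor` with `vec_to_tensor` and the run step unpacked them with `tensor_to_vec`; now the integers travel as a `tvm::ffi::Array<int64_t>` and are copied back into a `std::vector` through iterators. Below is a standalone sketch of that round trip, with a hypothetical `PlanInfo` struct standing in for `HolisticPlanInfo<2>` and hypothetical `Plan`/`Run` functions.

```cpp
// Sketch of the Array<int64_t> round trip used by the Plan/Run pair above.
// PlanInfo is a hypothetical stand-in for HolisticPlanInfo<2>.
#include <cstdint>
#include <vector>

#include "tvm_ffi_utils.h"  // same helper header the refactored csrc/ files include

using tvm::ffi::Array;

struct PlanInfo {
  int64_t num_blks_x{0};
  int64_t num_blks_y{0};

  std::vector<int64_t> ToVector() const { return {num_blks_x, num_blks_y}; }
  void FromVector(const std::vector<int64_t>& vec) {
    num_blks_x = vec[0];
    num_blks_y = vec[1];
  }
};

// Plan side: serialize scheduling decisions into an FFI-friendly array,
// replacing the old `return vec_to_tensor(plan_info.ToVector());`.
Array<int64_t> Plan() {
  PlanInfo info;
  info.num_blks_x = 8;
  info.num_blks_y = 4;
  return Array(info.ToVector());
}

// Run side: rebuild the struct from the array handed back across the binding,
// replacing the old `plan_info.FromVector(tensor_to_vec(plan_info_vec));`.
void Run(Array<int64_t> plan_info_vec) {
  PlanInfo info;
  info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
  // ... use info.num_blks_x / info.num_blks_y to configure the launch ...
}
```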

0 commit comments
