 #include <flashinfer/attention/scheduler.cuh>
 #include <flashinfer/layout.cuh>
 #include <flashinfer/pos_enc.cuh>
-#include <optional>
 
 #include "batch_attention_config.inc"
-#include "pytorch_conversion_utils.h"
-#include "pytorch_extension_utils.h"
+#include "tvm_ffi_utils.h"
 
 namespace flashinfer {
 
+using tvm::ffi::Array;
+using tvm::ffi::Optional;
+
 template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
           MaskMode MASK_MODE, typename AttentionVariant, typename Params>
 cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
@@ -34,80 +35,73 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa
 
 using namespace flashinfer;
 
-at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer,
-                                   at::Tensor int_workspace_buffer,
-                                   at::Tensor page_locked_int_workspace_buffer,
-                                   at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
-                                   int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
-                                   int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
+                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
+                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
+                                       int64_t num_qo_heads, int64_t num_kv_heads,
+                                       int64_t head_dim_o, bool causal) {
   size_t float_workspace_size_in_bytes =
-      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
-      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
+      int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
 
   HolisticPlanInfo<2> plan_info;
 
-  const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
-  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  cudaSetDevice(float_workspace_buffer->device.device_id);
+  const cudaStream_t stream = get_stream(float_workspace_buffer->device);
 
   cudaError_t status = TwoStageHolisticPlan<IdType>(
-      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
-      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
-      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
-      kv_indptr.data_ptr<IdType>(), kv_len.data_ptr<IdType>(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim_o, causal, stream);
+      float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data,
+      page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info,
+      static_cast<IdType*>(qo_indptr->data), static_cast<IdType*>(kv_indptr->data),
+      static_cast<IdType*>(kv_len->data), batch_size, num_qo_heads, num_kv_heads, head_dim_o,
+      causal, stream);
 
-  TORCH_CHECK(status == cudaSuccess,
-              "Failed to plan persistent paged attention, error: ", cudaGetErrorString(status));
+  TVM_FFI_ICHECK(status == cudaSuccess)
+      << "Failed to plan persistent paged attention, error: " << cudaGetErrorString(status);
 
-  return vec_to_tensor(plan_info.ToVector());
+  return Array(plan_info.ToVector());
 }
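
Aside (not part of the diff): the plan path above swaps the PyTorch tensor accessors for the DLPack-style fields of the TVM FFI `Tensor`: `.size(0)`, `.element_size()` and `.data_ptr()` become `->shape[0]`, `get_element_size(...)` and `->data`, and the device and stream now come from the tensor's `DLDevice`. A minimal sketch of what `get_element_size` is assumed to compute; the helper below is hypothetical, not the actual `tvm_ffi_utils.h` implementation:

    // Element size in bytes from a DLPack dtype: bits * lanes, rounded up to whole bytes.
    inline size_t element_size_sketch(const DLTensor* t) {
      return (static_cast<size_t>(t->dtype.bits) * t->dtype.lanes + 7) / 8;
    }

Also note that `cudaSetDevice`, unlike `c10::cuda::OptionalCUDAGuard`, does not restore the previously active device when the function returns.
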
 
-void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-                            at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache,
-                            at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o,
-                            std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
-                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
-                            int64_t page_size,
+void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
+                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
+                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
+                            int64_t num_kv_heads, int64_t page_size,
                             double v_scale,  // must use double due to pytorch binding
                             double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
   HolisticPlanInfo<2> plan_info;
-  plan_info.FromVector(tensor_to_vec(plan_info_vec));
-
-  auto device = q.device();
+  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
 
-  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
-  void* int_buffer_ptr = int_workspace_buffer.data_ptr();
+  void* float_buffer_ptr = float_workspace_buffer->data;
+  void* int_buffer_ptr = int_workspace_buffer->data;
 
   const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
 
-  auto q_scalar_type = q.scalar_type();
-  auto kv_scalar_type = k_cache.scalar_type();
-
   // NOTE (Yilong): assume both q and o are NHD
-  unsigned int q_stride_n = q.stride(0);
-  unsigned int q_stride_h = q.stride(1);
+  unsigned int q_stride_n = q->strides[0];
+  unsigned int q_stride_h = q->strides[1];
 
   // layout only constraint paged KV
   const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code);
-  unsigned int k_stride_page = k_cache.stride(0);
-  unsigned int v_stride_page = v_cache.stride(0);
+  unsigned int k_stride_page = k_cache->strides[0];
+  unsigned int v_stride_page = v_cache->strides[0];
   unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h;
   if (kv_layout == QKVLayout::kNHD) {
-    k_stride_h = k_cache.stride(2);
-    k_stride_n = k_cache.stride(1);
-    v_stride_h = v_cache.stride(2);
-    v_stride_n = v_cache.stride(1);
+    k_stride_h = k_cache->strides[2];
+    k_stride_n = k_cache->strides[1];
+    v_stride_h = v_cache->strides[2];
+    v_stride_n = v_cache->strides[1];
   } else {
-    k_stride_h = k_cache.stride(1);
-    k_stride_n = k_cache.stride(2);
-    v_stride_h = v_cache.stride(1);
-    v_stride_n = v_cache.stride(2);
+    k_stride_h = k_cache->strides[1];
+    k_stride_n = k_cache->strides[2];
+    v_stride_h = v_cache->strides[1];
+    v_stride_n = v_cache->strides[2];
   }
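
Aside (not part of the diff): the branch above only decides which stride index is per-token and which is per-head. For a paged cache laid out as [num_pages, page_size, num_kv_heads, head_dim] (NHD) the token stride is `strides[1]` and the head stride is `strides[2]`; for [num_pages, num_kv_heads, page_size, head_dim] (HND) the two swap. A hypothetical helper with the same logic, assuming the DLPack `strides` array is populated:

    // Select per-token (n) and per-head (h) strides for one paged K/V cache tensor.
    inline void kv_strides_sketch(const DLTensor* kv, bool is_nhd,
                                  unsigned int& stride_n, unsigned int& stride_h) {
      stride_n = static_cast<unsigned int>(kv->strides[is_nhd ? 1 : 2]);
      stride_h = static_cast<unsigned int>(kv->strides[is_nhd ? 2 : 1]);
    }
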
 
-  const c10::cuda::OptionalCUDAGuard device_guard(device);
-  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+  cudaSetDevice(q->device.device_id);
+  const cudaStream_t stream = get_stream(q->device);
 
   DISPATCH_context(
       DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -116,17 +110,17 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
         IdType* len_kv_chunk =
             GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.len_kv_chunk_offset);
         for (int i = 0; i < 2; i++) {
-          params[i].q = static_cast<DTypeQ*>(q.data_ptr());
-          params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
-          params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr());
+          params[i].q = static_cast<DTypeQ*>(q->data);
+          params[i].k = static_cast<DTypeKV*>(k_cache->data);
+          params[i].v = static_cast<DTypeKV*>(v_cache->data);
 
           params[i].q_indptr =
               GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset);
           params[i].kv_indptr =
               GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset);
           params[i].partial_indptr = GetPtrFromBaseOffset<IdType>(
               int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset);
-          params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr());
+          params[i].kv_indices = static_cast<int*>(kv_indices->data);
           params[i].q_len =
               GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset);
           params[i].kv_len =
@@ -143,9 +137,9 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
               GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
           params[i].len_kv_chunk = len_kv_chunk + i;
 
-          params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
+          params[i].final_o = static_cast<DTypeO*>(o->data);
           params[i].final_lse =
-              maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
+              maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr;
           params[i].partial_o =
               GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
           params[i].partial_lse =
@@ -184,8 +178,8 @@ void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_wo
         cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO,
                                                            MASK_MODE, AttentionVariant>(
             params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream);
-        TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ",
-                    cudaGetErrorString(status));
+        TVM_FFI_ICHECK(status == cudaSuccess)
+            << "Failed to run persistent paged attention, error: " << cudaGetErrorString(status);
         return true;
       });
 }
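
Aside (not part of the diff): the plan metadata now round-trips between the two entry points as a plain `Array<int64_t>` instead of a CPU int tensor, which is why the `vec_to_tensor`/`tensor_to_vec` helpers and the `pytorch_conversion_utils.h` include are gone. The producer/consumer pattern used above boils down to:

    // Producer (Plan): pack the plan metadata into a TVM FFI Array.
    Array<int64_t> packed = Array(plan_info.ToVector());
    // Consumer (Run): copy it back into a std::vector before rehydrating the plan.
    plan_info.FromVector(std::vector<int64_t>(packed.begin(), packed.end()));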