@@ -27,6 +27,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
                       const paddle::optional<paddle::Tensor> &v,
                       const paddle::optional<paddle::Tensor> &rope_sin,
                       const paddle::optional<paddle::Tensor> &rope_cos,
+                      int num_heads,
+                      int head_dim,
                       int num_kv_heads,
                       float scale,
                       int block_size,
@@ -86,32 +88,36 @@ void PagedAttnKernel(const paddle::Tensor& q,
                     common::errors::InvalidArgument(
                         "paged_attention expects seq_lens is contiguous"));
   // check dim and shape
-  // k_cache: [num_blocks, kv_num_heads, block_size, head_size]
-  // v_cache: [num_blocks, kv_num_heads, block_size, head_size]
+  // k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
+  // v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
   // block_table: [num_seqs, max_num_blocks_per_seq]
   // seq_lens: [num_seqs]
   // q and out:
-  // merged_qkv = false: [num_seqs, num_heads, head_size]
-  // merged_qkv = true: [num_seqs, num_heads+2*num_kv_heads, head_size]
+  // if merged_qkv = false:
+  //   q:   [num_seqs, hidden_size]
+  //   out: [num_seqs, hidden_size]
+  // if merged_qkv = true:
+  //   q:   [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
+  //   out: [num_seqs, hidden_size]
 
   const auto& q_dims = q.dims();
   PADDLE_ENFORCE_EQ(q_dims.size(),
-                    3,
+                    2,
                     common::errors::InvalidArgument(
                         "paged_attn receive query dims is "
-                        "[num_seqs, num_heads, head_size]"));
+                        "[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
   PADDLE_ENFORCE_EQ(out.dims().size(),
-                    3,
+                    2,
                     common::errors::InvalidArgument(
                         "paged_attn receive out dims is "
-                        "[num_seqs, num_heads, head_size]"));
+                        "[num_seqs, hidden_size]"));
 
   const auto& kv_cache_dims = k_cache.dims();
   PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
                     4,
                     common::errors::InvalidArgument(
                         "paged_attn receive kv cache dims is "
-                        "[num_blocks, kv_num_heads, block_size, head_size]"));
+                        "[num_blocks, kv_num_heads, block_size, head_dim]"));
 
   const auto& block_table_dims = block_table.dims();
   PADDLE_ENFORCE_EQ(block_table_dims.size(),
@@ -127,8 +133,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
                     "paged_attn receive seq_lens dims is [num_seqs]"));
 
   int num_seqs = q_dims[0];
-  int num_heads = merged_qkv ? q_dims[1] - 2 * num_kv_heads : q_dims[1];
-  int head_size = q_dims[2];
   int max_num_blocks_per_seq = block_table_dims[1];
   int q_stride = q.strides()[0];
   int num_blocks = kv_cache_dims[0];
@@ -142,9 +146,9 @@ void PagedAttnKernel(const paddle::Tensor& q,
                     common::errors::InvalidArgument(
                         "kv_cache_dims[2] must be equal to block_size"));
   PADDLE_ENFORCE_EQ(kv_cache_dims[3],
-                    head_size,
+                    head_dim,
                     common::errors::InvalidArgument(
-                        "kv_cache_dims[3] must be equal to head_size"));
+                        "kv_cache_dims[3] must be equal to head_dim"));
   PADDLE_ENFORCE_EQ(block_table_dims[0],
                     num_seqs,
                     common::errors::InvalidArgument(
@@ -162,14 +166,13 @@ void PagedAttnKernel(const paddle::Tensor& q,
   const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
   const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
 
-  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
   cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
 
   size_t workspace_size = 0;
   CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
                                                    num_heads,
                                                    num_kv_heads,
-                                                   head_size,
+                                                   head_dim,
                                                    block_size,
                                                    max_context_len,
                                                    &workspace_size));
@@ -189,7 +192,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
                                        num_seqs,
                                        num_heads,
                                        num_kv_heads,
-                                       head_size,
+                                       head_dim,
                                        q_stride,
                                        kv_block_stride,
                                        kv_head_stride,
@@ -215,6 +218,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
                                       const paddle::optional<paddle::Tensor> &v,
                                       const paddle::optional<paddle::Tensor> &rope_sin,
                                       const paddle::optional<paddle::Tensor> &rope_cos,
+                                      int num_heads,
+                                      int head_dim,
                                       int num_kv_heads,
                                       float scale,
                                       int block_size,
@@ -228,11 +233,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
                                       bool merged_qkv) {
 
   const auto dtype = q.dtype();
-  auto out_shape = q.shape();
-  if (merged_qkv) {
-    out_shape[1] -= 2 * num_kv_heads;
-  }
-  auto out = paddle::empty(out_shape, dtype, q.place());
+  auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
 
   switch (dtype) {
     case paddle::DataType::BFLOAT16:
@@ -246,6 +247,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
                                        v,
                                        rope_sin,
                                        rope_cos,
+                                       num_heads,
+                                       head_dim,
                                        num_kv_heads,
                                        scale,
                                        block_size,
@@ -270,6 +273,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
                                        v,
                                        rope_sin,
                                        rope_cos,
+                                       num_heads,
+                                       head_dim,
                                        num_kv_heads,
                                        scale,
                                        block_size,
@@ -299,6 +304,8 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
                                                       const std::vector<int64_t>& v_shape,
                                                       const std::vector<int64_t>& rope_sin_shape,
                                                       const std::vector<int64_t>& rope_cos_shape,
+                                                      int num_heads,
+                                                      int head_dim,
                                                       int num_kv_heads,
                                                       float scale,
                                                       int block_size,
@@ -311,36 +318,13 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
                                                       bool use_sqrt_alibi,
                                                       bool merged_qkv) {
   if (merged_qkv) {
-    int64_t num_tokens = q_shape[0];
-    int64_t num_heads = q_shape[1] - 2 * num_kv_heads;
-    int64_t head_dim = q_shape[2];
-    return {{num_tokens, num_heads, head_dim}};
+    return {{q_shape[0], num_heads * head_dim}};
   } else {
     return {q_shape};
   }
 }
 
-std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
-                                                  const paddle::DataType& k_cache_dtype,
-                                                  const paddle::DataType& v_cache_dtype,
-                                                  const paddle::DataType& block_table_dtype,
-                                                  const paddle::DataType& seq_lens_dtype,
-                                                  const paddle::DataType& alibi_slopes_dtype,
-                                                  const paddle::DataType& k_dtype,
-                                                  const paddle::DataType& v_dtype,
-                                                  const paddle::DataType& rope_sin_dtype,
-                                                  const paddle::DataType& rope_cos_dtype,
-                                                  int num_kv_heads,
-                                                  float scale,
-                                                  int block_size,
-                                                  int max_context_len,
-                                                  bool causal,
-                                                  int window_left,
-                                                  int window_right,
-                                                  float softcap,
-                                                  bool enable_cuda_graph,
-                                                  bool use_sqrt_alibi,
-                                                  bool merged_qkv) {
+std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype) {
   return {q_dtype};
 }
 
@@ -351,7 +335,9 @@ PD_BUILD_STATIC_OP(paged_attn)
              paddle::Optional("v"), paddle::Optional("rope_sin"),
              paddle::Optional("rope_cos")})
     .Outputs({"out"})
-    .Attrs({"num_kv_heads:int",
+    .Attrs({"num_heads:int",
+            "head_dim:int",
+            "num_kv_heads:int",
             "scale:float",
             "block_size:int",
             "max_context_len:int",
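For reference, a minimal standalone sketch (not part of the patch) of the 2-D shape convention this diff moves to: when `merged_qkv` is true, `q` packs the query and key/value heads along its second dimension, while `out` always carries `num_heads * head_dim` columns. The sizes below are made up purely for illustration.

```cpp
// Illustrative sketch only: mirrors the q/out shape comments in the diff.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical attention configuration (not taken from the patch).
  const int64_t num_seqs = 4, num_heads = 32, num_kv_heads = 8, head_dim = 128;
  const bool merged_qkv = true;

  // merged_qkv packs query + key + value heads along dim 1 of q;
  // the output always holds only the query heads (hidden_size).
  const int64_t q_cols = merged_qkv
                             ? (num_heads + 2 * num_kv_heads) * head_dim
                             : num_heads * head_dim;
  const std::vector<int64_t> q_shape = {num_seqs, q_cols};
  const std::vector<int64_t> out_shape = {num_seqs, num_heads * head_dim};

  std::cout << "q:   [" << q_shape[0] << ", " << q_shape[1] << "]\n"
            << "out: [" << out_shape[0] << ", " << out_shape[1] << "]\n";
  return 0;
}
```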