Commit b977c0e

Add with_output version AppendAttention

1 parent c011cb8 commit b977c0e

File tree

8 files changed: +1332 −94 lines changed

custom_ops/gpu_ops/append_attention.cu

Lines changed: 397 additions & 38 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 43 additions & 0 deletions
@@ -90,6 +90,48 @@ std::vector<paddle::Tensor> AppendAttention(
     const int speculate_max_draft_token_num, const bool causal,
     const bool speculate_decoder);
 
+void AppendAttentionWithOutput(
+    const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
+    const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
+    const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
+    const paddle::Tensor &encoder_tile_ids_per_batch,
+    const paddle::Tensor &encoder_num_blocks,
+    const paddle::Tensor &kv_batch_ids,
+    const paddle::Tensor &kv_tile_ids_per_batch,
+    const paddle::Tensor &kv_num_blocks,
+    const paddle::Tensor &decoder_batch_ids,
+    const paddle::Tensor &decoder_tile_ids_per_batch,
+    const paddle::Tensor &decoder_num_blocks,
+    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    paddle::Tensor &res,
+    const paddle::optional<paddle::Tensor> &rotary_embs,
+    const paddle::optional<paddle::Tensor> &attn_mask,
+    const paddle::optional<paddle::Tensor> &qkv_bias,
+    const paddle::optional<paddle::Tensor> &qkv_out_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_zp,
+    const paddle::optional<paddle::Tensor> &cache_v_zp,
+    const paddle::optional<paddle::Tensor> &out_linear_shifts,
+    const paddle::optional<paddle::Tensor> &out_linear_smooths,
+    const paddle::optional<paddle::Tensor> &kv_signal_data,
+    const paddle::optional<paddle::Tensor> &q_norm_weight,
+    const paddle::optional<paddle::Tensor> &k_norm_weight,
+    const float rms_norm_eps,
+    const std::string &compute_dtype, const std::string &cache_quant_type_str,
+    const bool use_neox_rotary_style, const bool rope_3d,
+    const int max_input_length, const float quant_max_bound,
+    const float quant_min_bound, const float out_linear_in_scale,
+    const int encoder_block_shape_q, const int decoder_block_shape_q,
+    const int max_partition_size, const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num, const bool causal,
+    const bool speculate_decoder);
+
 std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
     const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,
@@ -828,6 +870,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * append_attention
    */
   m.def("append_attention", &AppendAttention, "append attention function");
+  m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
   /**
    * gqa_rope_write_cache.cu
    * gqa_rope_write_cache
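The new declaration mirrors AppendAttention, but it threads a mutable paddle::Tensor &res through the argument list (right after max_len_kv) and returns void, so the caller owns the output allocation. Below is a minimal call-site sketch of the contrast, assuming the Python wrappers this commit exports; shared_args is a hypothetical stand-in for the long cache/metadata argument list both ops take, split into the arguments before and after the res slot.

import paddle

from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,              # allocates and returns its output
    append_attention_with_output,  # fills a caller-provided tensor
)


def attend(qkv, shared_args, num_heads, head_dim, prealloc=False):
    """Hypothetical helper contrasting the two call styles."""
    if not prealloc:
        # Original style: the op allocates and returns the output.
        return append_attention(qkv, *shared_args[0], *shared_args[1])
    # New style: pre-allocate once (e.g. to reuse a buffer across steps),
    # then let the kernel write into it in place.
    res = paddle.empty([qkv.shape[0], num_heads * head_dim], dtype=qkv.dtype)
    before_res, after_res = shared_args  # split around the `res` slot (after max_len_kv)
    append_attention_with_output(qkv, *before_res, res, *after_res)
    return res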

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 146 additions & 53 deletions
@@ -24,6 +24,7 @@
 
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    append_attention_with_output,
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
     init_signal_layerwise,
@@ -85,6 +86,7 @@ def __init__(
         head_dim: int,
         encoder_block_shape_q: int = -1,
         decoder_block_shape_q: int = -1,
+        use_output: bool = True,
     ) -> None:
         """
         AppendAttentionBackend __init__
@@ -121,6 +123,7 @@ def __init__(
             fd_config.parallel_config.expert_parallel_rank = 0
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        self.use_output = use_output
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata so all layers in the forward pass can reuse it."""
@@ -228,57 +231,147 @@ def forward_mixed(
                 layer.layer_id + self.start_layer_index,
             )
 
-        res = append_attention(
-            qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
-            forward_meta.seq_lens_encoder,
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.batch_id_per_token,
-            forward_meta.cu_seqlens_q,
-            metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
-            forward_meta.decoder_num_blocks_cpu,
-            forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
-            metadata.rotary_embs,
-            metadata.attn_mask,
-            layer.qkv_bias,
-            layer.qkv_scale,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
-            getattr(layer, "cache_k_out_scale", None),
-            getattr(layer, "cache_v_out_scale", None),
-            getattr(layer, "cache_k_zp", None),
-            getattr(layer, "cache_v_zp", None),
-            layer.linear_shift,
-            layer.linear_smooth,
-            metadata.kv_signal_data_list[layer.layer_id],
-            getattr(layer, "q_norm_weight", None),
-            getattr(layer, "k_norm_weight", None),
-            getattr(layer, "rms_norm_eps", 1e-6),
-            metadata._fuse_kernel_compute_dtype,
-            getattr(layer, "cache_quant_type_str", "none"),
-            layer.use_neox_rotary_style,
-            self.rope_3d,
-            self.max_seq_len,
-            getattr(layer, "quant_max_bound", 0.0),
-            getattr(layer, "quant_min_bound", 0.0),
-            getattr(layer, "out_scale", -1.0),
-            self.encoder_block_shape_q,
-            self.decoder_block_shape_q,
-            metadata.max_partition_size,
-            metadata.encoder_max_partition_size,
-            self.speculate_max_draft_token_num + 1,
-            self.causal,
-            self.speculative_method is not None,
-        )[0]
+        if self.use_output:
+            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
+            cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
+            compute_type = metadata._fuse_kernel_compute_dtype
+            out_scale = getattr(layer, "out_scale", -1.0)
+            # 1. Get the output dtype.
+            qkv_dtype = qkv.dtype
+            if qkv_dtype == paddle.float16:
+                D_type = paddle.float16
+            elif qkv_dtype == paddle.bfloat16:
+                D_type = paddle.bfloat16
+            elif qkv_dtype == paddle.int32:
+                if compute_type == "bf16":
+                    D_type = paddle.bfloat16
+                elif compute_type == "fp16":
+                    D_type = paddle.float16
+                else:
+                    raise NotImplementedError("Only compute dtypes 'bf16' and 'fp16' are supported.")
+            else:
+                raise NotImplementedError("Only qkv dtypes 'float16', 'bfloat16', and 'int32' are supported.")
+            # 2. Extract the related shape parameters.
+            token_nums = qkv.shape[0]
+            head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
+            q_num_heads = self.num_heads
+            # 3. Allocate the output tensor with the matching dtype.
+            if out_scale > 0.0:
+                if abs(quant_max_bound - 127) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
+                elif abs(quant_max_bound - 448) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
+                else:
+                    raise NotImplementedError("Only quant_max_bound values 127 and 448 are supported.")
+            else:
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
+
+            append_attention_with_output(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                res,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
+        else:
+            res = append_attention(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
         return res
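The allocation rules in the use_output branch are easy to get wrong, so here is the same dispatch factored into a standalone sketch; _alloc_attention_output is a hypothetical helper, not part of the commit:

import paddle


def _alloc_attention_output(qkv, compute_dtype, out_scale, quant_max_bound,
                            num_heads, head_dim, cache_quant_type="none"):
    """Hypothetical sketch of the output-allocation rules above."""
    # Float qkv keeps its own dtype; int32 (pre-quantized) qkv falls back
    # to the fused-kernel compute dtype ("bf16" or "fp16").
    if qkv.dtype in (paddle.float16, paddle.bfloat16):
        dtype = qkv.dtype
    elif qkv.dtype == paddle.int32:
        if compute_dtype not in ("bf16", "fp16"):
            raise NotImplementedError("compute dtype must be 'bf16' or 'fp16'")
        dtype = paddle.bfloat16 if compute_dtype == "bf16" else paddle.float16
    else:
        raise NotImplementedError("qkv dtype must be float16, bfloat16, or int32")
    # cache_int4_zp doubles the per-head width, mirroring the rule in the diff.
    head_dims = head_dim * 2 if cache_quant_type == "cache_int4_zp" else head_dim
    # A positive out_scale means the op emits quantized output; the quant
    # bound selects int8 (127) or float8_e4m3fn (448).
    if out_scale > 0.0:
        if abs(quant_max_bound - 127) < 1e-6:
            dtype = "int8"
        elif abs(quant_max_bound - 448) < 1e-6:
            dtype = "float8_e4m3fn"
        else:
            raise NotImplementedError("quant_max_bound must be 127 or 448")
    return paddle.empty([qkv.shape[0], num_heads * head_dims], dtype=dtype)

Keeping the shape and dtype decisions in Python and handing the kernel a pre-sized buffer is one plausible motivation for the with_output variant, e.g. buffer reuse or capture scenarios where the op should not allocate internally; the commit itself does not state the rationale.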

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def forward_mixed(
             self.speculate_max_draft_token_num + 1,
             self.causal,
             self.speculative_method is not None,
-        )[0]
+        )
 
         if metadata.max_len_tensor_cpu[1] > 0:
             merge_prefill_decode_output(
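The dropped [0] tracks a return-convention change: append_attention now hands back the output tensor itself rather than a one-element list, so keeping the index would slice row 0 of the tensor instead of unwrapping a list. A tiny self-contained illustration with stand-in functions:

import paddle


def old_style(x):
    return [x * 2]  # returned a one-element list; callers unwrapped with [0]


def new_style(x):
    return x * 2    # returns the tensor directly


x = paddle.ones([4, 8])
res_old = old_style(x)[0]  # correct unwrap under the old convention
res_new = new_style(x)     # no unwrap needed now
# new_style(x)[0] would instead be an [8]-shaped row slice, a silent shape bug.
assert res_old.shape == res_new.shape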

fastdeploy/model_executor/layers/attention/ops/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 """
 
-from .append_attention import append_attention
+from .append_attention import append_attention, append_attention_with_output
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
 from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -25,6 +25,7 @@
 __all__ = [
     "get_block_shape_and_split_kv_block",
     "append_attention",
+    "append_attention_with_output",
     "open_shm_and_get_meta_signal",
     "init_signal_layerwise",
     "gqa_rope_write_cache",
