Commit e93d4cf

Add with_output version AppendAttention (#3302)

* get use_output from fd_config
* add clear TODO description
* add mask_offset para to align with develop
* fix bug
* fix use_output logic
* fix sot bug

1 parent: 94ded43

File tree

8 files changed: +1366 -96 lines


custom_ops/gpu_ops/append_attention.cu

Lines changed: 408 additions & 39 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 44 additions & 0 deletions
@@ -91,6 +91,49 @@ std::vector<paddle::Tensor> AppendAttention(
     const int speculate_max_draft_token_num, const bool causal,
     const bool speculate_decoder);

+void AppendAttentionWithOutput(
+    const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
+    const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
+    const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
+    const paddle::Tensor &encoder_tile_ids_per_batch,
+    const paddle::Tensor &encoder_num_blocks,
+    const paddle::Tensor &kv_batch_ids,
+    const paddle::Tensor &kv_tile_ids_per_batch,
+    const paddle::Tensor &kv_num_blocks,
+    const paddle::Tensor &decoder_batch_ids,
+    const paddle::Tensor &decoder_tile_ids_per_batch,
+    const paddle::Tensor &decoder_num_blocks,
+    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    paddle::Tensor &fmha_out,
+    const paddle::optional<paddle::Tensor> &rotary_embs,
+    const paddle::optional<paddle::Tensor> &attn_mask,
+    const paddle::optional<paddle::Tensor> &qkv_bias,
+    const paddle::optional<paddle::Tensor> &qkv_out_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_zp,
+    const paddle::optional<paddle::Tensor> &cache_v_zp,
+    const paddle::optional<paddle::Tensor> &out_linear_shifts,
+    const paddle::optional<paddle::Tensor> &out_linear_smooths,
+    const paddle::optional<paddle::Tensor> &mask_offset,
+    const paddle::optional<paddle::Tensor> &kv_signal_data,
+    const paddle::optional<paddle::Tensor> &q_norm_weight,
+    const paddle::optional<paddle::Tensor> &k_norm_weight,
+    const float rms_norm_eps,
+    const std::string &compute_dtype, const std::string &cache_quant_type_str,
+    const bool use_neox_rotary_style, const bool rope_3d,
+    const int max_input_length, const float quant_max_bound,
+    const float quant_min_bound, const float out_linear_in_scale,
+    const int encoder_block_shape_q, const int decoder_block_shape_q,
+    const int max_partition_size, const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num, const bool causal,
+    const bool speculate_decoder);
+
 std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
     const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,

@@ -881,6 +924,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * append_attention
    */
   m.def("append_attention", &AppendAttention, "append attention function");
+  m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
   /**
    * gqa_rope_write_cache.cu
    * gqa_rope_write_cache
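
The only functional difference from AppendAttention is the non-const fmha_out parameter after max_len_kv: the caller allocates the output tensor and the kernel writes into it instead of returning a fresh one. A minimal toy sketch of the two conventions from the Python side (the softmax stand-in and both function names are hypothetical, not the real kernel):

import paddle

def op_returning(q):
    # append_attention style: the op allocates and returns its output.
    return paddle.nn.functional.softmax(q, axis=-1)

def op_with_output(q, out):
    # append_attention_with_output style: the caller owns `out` and the op
    # writes into it, mirroring the `paddle::Tensor &fmha_out` parameter.
    paddle.assign(paddle.nn.functional.softmax(q, axis=-1), output=out)

q = paddle.rand([4, 8], dtype="float32")
out = paddle.empty([4, 8], dtype=q.dtype)  # pre-allocated by the caller
op_with_output(q, out)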

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 147 additions & 54 deletions
@@ -24,6 +24,7 @@

 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    append_attention_with_output,
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
     init_signal_layerwise,

@@ -122,6 +123,7 @@ def __init__(
         fd_config.parallel_config.expert_parallel_rank = 0

         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        self.use_output = not fd_config.graph_opt_config.full_cuda_graph

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""

@@ -229,58 +231,149 @@ def forward_mixed(
                 layer.layer_id + self.start_layer_index,
             )

-        res = append_attention(
-            qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
-            forward_meta.seq_lens_encoder,
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.batch_id_per_token,
-            forward_meta.cu_seqlens_q,
-            metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
-            forward_meta.decoder_num_blocks_cpu,
-            forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
-            metadata.rotary_embs,
-            metadata.attn_mask,
-            layer.qkv_bias,
-            layer.qkv_scale,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
-            getattr(layer, "cache_k_out_scale", None),
-            getattr(layer, "cache_v_out_scale", None),
-            getattr(layer, "cache_k_zp", None),
-            getattr(layer, "cache_v_zp", None),
-            layer.linear_shift,
-            layer.linear_smooth,
-            metadata.mask_offset,
-            metadata.kv_signal_data_list[layer.layer_id],
-            getattr(layer, "q_norm_weight", None),
-            getattr(layer, "k_norm_weight", None),
-            getattr(layer, "rms_norm_eps", 1e-6),
-            metadata._fuse_kernel_compute_dtype,
-            getattr(layer, "cache_quant_type_str", "none"),
-            layer.use_neox_rotary_style,
-            self.rope_3d,
-            self.max_seq_len,
-            getattr(layer, "quant_max_bound", 0.0),
-            getattr(layer, "quant_min_bound", 0.0),
-            getattr(layer, "out_scale", -1.0),
-            self.encoder_block_shape_q,
-            self.decoder_block_shape_q,
-            metadata.max_partition_size,
-            metadata.encoder_max_partition_size,
-            self.speculate_max_draft_token_num + 1,
-            self.causal,
-            self.speculative_method is not None,
-        )[0]
+        if self.use_output:
+            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
+            cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
+            compute_type = metadata._fuse_kernel_compute_dtype
+            out_scale = getattr(layer, "out_scale", -1.0)
+            # 1. get output datatype
+            qkv_dtype = qkv.dtype
+            if qkv_dtype == paddle.float16:
+                D_type = paddle.float16
+            elif qkv_dtype == paddle.bfloat16:
+                D_type = paddle.bfloat16
+            elif qkv_dtype == paddle.int32:
+                if compute_type == "bf16":
+                    D_type = paddle.bfloat16
+                elif compute_type == "fp16":
+                    D_type = paddle.float16
+                else:
+                    raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].")
+            else:
+                raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].")
+            # 2. Extract related parameters
+            token_nums = qkv.shape[0]
+            head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
+            q_num_heads = self.num_heads
+            # 3. generate output tensor of different dtypes
+            if out_scale > 0.0:
+                if abs(quant_max_bound - 127) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
+                elif abs(quant_max_bound - 448) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
+                else:
+                    raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
+            else:
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
+
+            append_attention_with_output(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                res,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.mask_offset,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
+        else:
+            res = append_attention(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.mask_offset,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
         return res
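
The core of the new branch is the pre-allocation of `res` before the kernel call; per the `self.use_output` flag set in `__init__`, this path is taken only when `full_cuda_graph` is disabled. Distilled into a standalone helper, the dtype-selection logic reads roughly as below (a sketch only; `alloc_attention_output` and its signature are inventions for illustration, not part of the commit):

import paddle

def alloc_attention_output(qkv, compute_type, out_scale, quant_max_bound,
                           cache_quant_type, num_heads, head_dim):
    # Pick the output dtype from the qkv dtype; int32 qkv (quantized input)
    # falls back to the fused-kernel compute dtype ("bf16" or "fp16").
    if qkv.dtype in (paddle.float16, paddle.bfloat16):
        out_dtype = qkv.dtype
    elif qkv.dtype == paddle.int32:
        if compute_type == "bf16":
            out_dtype = paddle.bfloat16
        elif compute_type == "fp16":
            out_dtype = paddle.float16
        else:
            raise NotImplementedError("compute dtype must be bf16 or fp16")
    else:
        raise NotImplementedError("qkv dtype must be float16, bfloat16 or int32")

    # cache_int4_zp doubles the effective head_dim (mirrors the branch above).
    if cache_quant_type == "cache_int4_zp":
        head_dim *= 2
    shape = [qkv.shape[0], num_heads * head_dim]  # [token_nums, hidden]

    if out_scale > 0.0:
        # Quantized output: quant_max_bound 127 -> int8, 448 -> float8_e4m3fn.
        if abs(quant_max_bound - 127) < 1e-6:
            return paddle.empty(shape, dtype="int8")
        if abs(quant_max_bound - 448) < 1e-6:
            return paddle.empty(shape, dtype="float8_e4m3fn")
        raise NotImplementedError("quant_max_bound must be 127 or 448")
    return paddle.empty(shape, dtype=out_dtype)

The resulting tensor is then passed to append_attention_with_output as the fmha_out argument (`res` in the diff), so the output buffer is owned by the caller rather than the op.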

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ def forward_mixed(
             self.speculate_max_draft_token_num + 1,
             self.causal,
             self.speculative_method is not None,
-        )[0]
+        )

         if metadata.max_len_tensor_cpu[1] > 0:
             merge_prefill_decode_output(

fastdeploy/model_executor/layers/attention/ops/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from .append_attention import append_attention
+from .append_attention import append_attention, append_attention_with_output
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
 from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_kv_signal_per_query import init_kv_signal_per_query

@@ -25,6 +25,7 @@
 __all__ = [
     "get_block_shape_and_split_kv_block",
     "append_attention",
+    "append_attention_with_output",
     "open_shm_and_get_meta_signal",
     "init_signal_layerwise",
     "gqa_rope_write_cache",
