
Commit 8572b8a

get use_output from fd_config
1 parent c011cb8 commit 8572b8a

File tree

8 files changed, +1334 -94 lines changed

custom_ops/gpu_ops/append_attention.cu

Lines changed: 397 additions & 38 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 43 additions & 0 deletions
@@ -90,6 +90,48 @@ std::vector<paddle::Tensor> AppendAttention(
     const int speculate_max_draft_token_num, const bool causal,
     const bool speculate_decoder);
 
+void AppendAttentionWithOutput(
+    const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
+    const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
+    const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
+    const paddle::Tensor &encoder_tile_ids_per_batch,
+    const paddle::Tensor &encoder_num_blocks,
+    const paddle::Tensor &kv_batch_ids,
+    const paddle::Tensor &kv_tile_ids_per_batch,
+    const paddle::Tensor &kv_num_blocks,
+    const paddle::Tensor &decoder_batch_ids,
+    const paddle::Tensor &decoder_tile_ids_per_batch,
+    const paddle::Tensor &decoder_num_blocks,
+    const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
+    paddle::Tensor &res,
+    const paddle::optional<paddle::Tensor> &rotary_embs,
+    const paddle::optional<paddle::Tensor> &attn_mask,
+    const paddle::optional<paddle::Tensor> &qkv_bias,
+    const paddle::optional<paddle::Tensor> &qkv_out_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor> &cache_k_zp,
+    const paddle::optional<paddle::Tensor> &cache_v_zp,
+    const paddle::optional<paddle::Tensor> &out_linear_shifts,
+    const paddle::optional<paddle::Tensor> &out_linear_smooths,
+    const paddle::optional<paddle::Tensor> &kv_signal_data,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
+    const float rms_norm_eps,
+    const std::string &compute_dtype, const std::string &cache_quant_type_str,
+    const bool use_neox_rotary_style, const bool rope_3d,
+    const int max_input_length, const float quant_max_bound,
+    const float quant_min_bound, const float out_linear_in_scale,
+    const int encoder_block_shape_q, const int decoder_block_shape_q,
+    const int max_partition_size, const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num, const bool causal,
+    const bool speculate_decoder);
+
 std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
     const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,

@@ -828,6 +870,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * append_attention
    */
   m.def("append_attention", &AppendAttention, "append attention function");
+  m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
   /**
    * gqa_rope_write_cache.cu
    * gqa_rope_write_cache
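
Note the convention change: AppendAttention returns its results as a vector of tensors, while AppendAttentionWithOutput returns void and writes the attention output into the mutable `res` argument that the caller allocates. A generic sketch of the two conventions in plain Paddle, separate from the FastDeploy bindings (the function names and shapes below are purely illustrative):

import paddle

# Returning convention: the op allocates and returns its output, like AppendAttention.
def scores_returning(q, k):
    return paddle.matmul(q, k, transpose_y=True)

# Out-parameter convention: the caller allocates `res` and the op fills it in place,
# like AppendAttentionWithOutput's mutable `paddle::Tensor &res`.
def scores_with_output(q, k, res):
    paddle.assign(paddle.matmul(q, k, transpose_y=True), output=res)

q = paddle.rand([4, 8], dtype="float32")
k = paddle.rand([4, 8], dtype="float32")
res = paddle.empty([4, 4], dtype="float32")
scores_with_output(q, k, res)  # res now holds what scores_returning(q, k) would return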

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 145 additions & 53 deletions
@@ -24,6 +24,7 @@
 
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    append_attention_with_output,
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
     init_signal_layerwise,

@@ -121,6 +122,7 @@ def __init__(
         fd_config.parallel_config.expert_parallel_rank = 0
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        self.use_output = fd_config.graph_opt_config.full_cuda_graph
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""

@@ -228,57 +230,147 @@ def forward_mixed(
                 layer.layer_id + self.start_layer_index,
             )
 
-        res = append_attention(
-            qkv,
-            forward_meta.caches[2 * layer.layer_id],
-            forward_meta.caches[2 * layer.layer_id + 1],
-            forward_meta.seq_lens_encoder,
-            forward_meta.seq_lens_decoder,
-            forward_meta.seq_lens_this_time,
-            forward_meta.batch_id_per_token,
-            forward_meta.cu_seqlens_q,
-            metadata.block_tables,
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,
-            forward_meta.decoder_tile_ids_per_batch,
-            forward_meta.decoder_num_blocks_cpu,
-            forward_meta.max_len_tensor_cpu,
-            metadata.max_len_kv,
-            metadata.rotary_embs,
-            metadata.attn_mask,
-            layer.qkv_bias,
-            layer.qkv_scale,
-            getattr(layer, "cache_k_scale", None),
-            getattr(layer, "cache_v_scale", None),
-            getattr(layer, "cache_k_out_scale", None),
-            getattr(layer, "cache_v_out_scale", None),
-            getattr(layer, "cache_k_zp", None),
-            getattr(layer, "cache_v_zp", None),
-            layer.linear_shift,
-            layer.linear_smooth,
-            metadata.kv_signal_data_list[layer.layer_id],
-            getattr(layer, "q_norm_weight", None),
-            getattr(layer, "k_norm_weight", None),
-            getattr(layer, "rms_norm_eps", 1e-6),
-            metadata._fuse_kernel_compute_dtype,
-            getattr(layer, "cache_quant_type_str", "none"),
-            layer.use_neox_rotary_style,
-            self.rope_3d,
-            self.max_seq_len,
-            getattr(layer, "quant_max_bound", 0.0),
-            getattr(layer, "quant_min_bound", 0.0),
-            getattr(layer, "out_scale", -1.0),
-            self.encoder_block_shape_q,
-            self.decoder_block_shape_q,
-            metadata.max_partition_size,
-            metadata.encoder_max_partition_size,
-            self.speculate_max_draft_token_num + 1,
-            self.causal,
-            self.speculative_method is not None,
-        )[0]
+        if self.use_output:
+            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
+            cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
+            compute_type = metadata._fuse_kernel_compute_dtype
+            out_scale = getattr(layer, "out_scale", -1.0)
+            # 1. get output datatype
+            qkv_dtype = qkv.dtype
+            if qkv_dtype == paddle.float16:
+                D_type = paddle.float16
+            elif qkv_dtype == paddle.bfloat16:
+                D_type = paddle.bfloat16
+            elif qkv_dtype == paddle.int32:
+                if compute_type == "bf16":
+                    D_type = paddle.bfloat16
+                elif compute_type == "fp16":
+                    D_type = paddle.float16
+                else:
+                    raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].")
+            else:
+                raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].")
+            # 2.Extract related parameters
+            token_nums = qkv.shape[0]
+            head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
+            q_num_heads = self.num_heads
+            # 3. generate output tensor of different dtypes
+            if out_scale > 0.0:
+                if abs(quant_max_bound - 127) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
+                elif abs(quant_max_bound - 448) < 0.000001:
+                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
+                else:
+                    raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
+            else:
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
+
+            append_attention_with_output(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                res,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
+        else:
+            res = append_attention(
+                qkv,
+                forward_meta.caches[2 * layer.layer_id],
+                forward_meta.caches[2 * layer.layer_id + 1],
+                forward_meta.seq_lens_encoder,
+                forward_meta.seq_lens_decoder,
+                forward_meta.seq_lens_this_time,
+                forward_meta.batch_id_per_token,
+                forward_meta.cu_seqlens_q,
+                metadata.block_tables,
+                metadata.encoder_batch_ids,
+                metadata.encoder_tile_ids_per_batch,
+                metadata.encoder_num_blocks,
+                metadata.kv_batch_ids,
+                metadata.kv_tile_ids_per_batch,
+                metadata.kv_num_blocks,
+                forward_meta.decoder_batch_ids,
+                forward_meta.decoder_tile_ids_per_batch,
+                forward_meta.decoder_num_blocks_cpu,
+                forward_meta.max_len_tensor_cpu,
+                metadata.max_len_kv,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.kv_signal_data_list[layer.layer_id],
+                getattr(layer, "q_norm_weight", None),
+                getattr(layer, "k_norm_weight", None),
+                getattr(layer, "rms_norm_eps", 1e-6),
+                metadata._fuse_kernel_compute_dtype,
+                getattr(layer, "cache_quant_type_str", "none"),
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                getattr(layer, "quant_max_bound", 0.0),
+                getattr(layer, "quant_min_bound", 0.0),
+                getattr(layer, "out_scale", -1.0),
+                self.encoder_block_shape_q,
+                self.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )
         return res
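
When self.use_output is set (it is taken from fd_config.graph_opt_config.full_cuda_graph in __init__ above), the backend preallocates the output tensor and passes it to append_attention_with_output, which fills it in place; otherwise the original append_attention path still allocates and returns the result. A condensed sketch of just the output-allocation rules, written as a standalone helper assuming plain Paddle; the helper name is illustrative and simply restates the logic shown in the diff:

import paddle

def allocate_attention_output(qkv, compute_dtype, cache_quant_type, out_scale,
                              quant_max_bound, num_heads, head_dim):
    # Illustrative restatement of the use_output allocation in forward_mixed above.
    token_nums = qkv.shape[0]
    # cache_int4_zp packs two values per element, so the flattened head dim doubles.
    head_dims = head_dim * 2 if cache_quant_type == "cache_int4_zp" else head_dim
    shape = [token_nums, num_heads * head_dims]
    if out_scale > 0.0:
        # Quantized output: int8 for a 127 bound, fp8 (e4m3) for a 448 bound.
        if abs(quant_max_bound - 127) < 1e-6:
            dtype = "int8"
        elif abs(quant_max_bound - 448) < 1e-6:
            dtype = "float8_e4m3fn"  # needs a Paddle build with fp8 support
        else:
            raise NotImplementedError("quant_max_bound must be 127 or 448")
    elif qkv.dtype in (paddle.float16, paddle.bfloat16):
        # Unquantized fp16/bf16 QKV: the output keeps the input dtype.
        dtype = qkv.dtype
    elif qkv.dtype == paddle.int32:
        # Already-quantized int32 QKV: the output follows the fused-kernel compute dtype.
        if compute_dtype not in ("bf16", "fp16"):
            raise NotImplementedError("compute_dtype must be 'bf16' or 'fp16'")
        dtype = paddle.bfloat16 if compute_dtype == "bf16" else paddle.float16
    else:
        raise NotImplementedError("qkv must be float16, bfloat16, or int32")
    return paddle.empty(shape, dtype=dtype).to(qkv.place)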

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def forward_mixed(
             self.speculate_max_draft_token_num + 1,
             self.causal,
             self.speculative_method is not None,
-        )[0]
+        )
 
         if metadata.max_len_tensor_cpu[1] > 0:
             merge_prefill_decode_output(

fastdeploy/model_executor/layers/attention/ops/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 """
 
-from .append_attention import append_attention
+from .append_attention import append_attention, append_attention_with_output
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
 from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_kv_signal_per_query import init_kv_signal_per_query

@@ -25,6 +25,7 @@
 __all__ = [
     "get_block_shape_and_split_kv_block",
     "append_attention",
+    "append_attention_with_output",
     "open_shm_and_get_meta_signal",
     "init_signal_layerwise",
     "gqa_rope_write_cache",
