24 | 24 |
25 | 25 | from fastdeploy.model_executor.layers.attention.ops import (
26 | 26 |     append_attention,
   | 27 | +     append_attention_with_output,
27 | 28 |     get_block_shape_and_split_kv_block,
28 | 29 |     init_kv_signal_per_query,
29 | 30 |     init_signal_layerwise,
@@ -121,6 +122,7 @@ def __init__(
121 | 122 |             fd_config.parallel_config.expert_parallel_rank = 0
122 | 123 |
123 | 124 |         self.rank, self.device_id = init_rank_and_device_id(fd_config)
    | 125 | +        self.use_output = fd_config.graph_opt_config.full_cuda_graph
124 | 126 |
125 | 127 |     def init_attention_metadata(self, forward_meta: ForwardMeta):
126 | 128 |         """Initialize attention metadata so that all layers in the forward pass can reuse it."""
@@ -228,57 +230,147 @@ def forward_mixed(
228 | 230 |                 layer.layer_id + self.start_layer_index,
229 | 231 |             )
230 | 232 |
231 |     | -        res = append_attention(
232 |     | -            qkv,
233 |     | -            forward_meta.caches[2 * layer.layer_id],
234 |     | -            forward_meta.caches[2 * layer.layer_id + 1],
235 |     | -            forward_meta.seq_lens_encoder,
236 |     | -            forward_meta.seq_lens_decoder,
237 |     | -            forward_meta.seq_lens_this_time,
238 |     | -            forward_meta.batch_id_per_token,
239 |     | -            forward_meta.cu_seqlens_q,
240 |     | -            metadata.block_tables,
241 |     | -            metadata.encoder_batch_ids,
242 |     | -            metadata.encoder_tile_ids_per_batch,
243 |     | -            metadata.encoder_num_blocks,
244 |     | -            metadata.kv_batch_ids,
245 |     | -            metadata.kv_tile_ids_per_batch,
246 |     | -            metadata.kv_num_blocks,
247 |     | -            forward_meta.decoder_batch_ids,
248 |     | -            forward_meta.decoder_tile_ids_per_batch,
249 |     | -            forward_meta.decoder_num_blocks_cpu,
250 |     | -            forward_meta.max_len_tensor_cpu,
251 |     | -            metadata.max_len_kv,
252 |     | -            metadata.rotary_embs,
253 |     | -            metadata.attn_mask,
254 |     | -            layer.qkv_bias,
255 |     | -            layer.qkv_scale,
256 |     | -            getattr(layer, "cache_k_scale", None),
257 |     | -            getattr(layer, "cache_v_scale", None),
258 |     | -            getattr(layer, "cache_k_out_scale", None),
259 |     | -            getattr(layer, "cache_v_out_scale", None),
260 |     | -            getattr(layer, "cache_k_zp", None),
261 |     | -            getattr(layer, "cache_v_zp", None),
262 |     | -            layer.linear_shift,
263 |     | -            layer.linear_smooth,
264 |     | -            metadata.kv_signal_data_list[layer.layer_id],
265 |     | -            getattr(layer, "q_norm_weight", None),
266 |     | -            getattr(layer, "k_norm_weight", None),
267 |     | -            getattr(layer, "rms_norm_eps", 1e-6),
268 |     | -            metadata._fuse_kernel_compute_dtype,
269 |     | -            getattr(layer, "cache_quant_type_str", "none"),
270 |     | -            layer.use_neox_rotary_style,
271 |     | -            self.rope_3d,
272 |     | -            self.max_seq_len,
273 |     | -            getattr(layer, "quant_max_bound", 0.0),
274 |     | -            getattr(layer, "quant_min_bound", 0.0),
275 |     | -            getattr(layer, "out_scale", -1.0),
276 |     | -            self.encoder_block_shape_q,
277 |     | -            self.decoder_block_shape_q,
278 |     | -            metadata.max_partition_size,
279 |     | -            metadata.encoder_max_partition_size,
280 |     | -            self.speculate_max_draft_token_num + 1,
281 |     | -            self.causal,
282 |     | -            self.speculative_method is not None,
283 |     | -        )[0]
    | 233 | +        if self.use_output:
    | 234 | +            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
    | 235 | +            cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
    | 236 | +            compute_type = metadata._fuse_kernel_compute_dtype
    | 237 | +            out_scale = getattr(layer, "out_scale", -1.0)
    | 238 | +            # 1. Determine the output dtype.
    | 239 | +            qkv_dtype = qkv.dtype
    | 240 | +            if qkv_dtype == paddle.float16:
    | 241 | +                D_type = paddle.float16
    | 242 | +            elif qkv_dtype == paddle.bfloat16:
    | 243 | +                D_type = paddle.bfloat16
    | 244 | +            elif qkv_dtype == paddle.int32:
    | 245 | +                if compute_type == "bf16":
    | 246 | +                    D_type = paddle.bfloat16
    | 247 | +                elif compute_type == "fp16":
    | 248 | +                    D_type = paddle.float16
    | 249 | +                else:
    | 250 | +                    raise NotImplementedError("Only compute dtypes in ['fp16', 'bf16'] are supported for int32 qkv.")
    | 251 | +            else:
    | 252 | +                raise NotImplementedError("Only qkv dtypes in ['float16', 'bfloat16', 'int32'] are supported.")
    | 253 | +            # 2. Extract the related shape parameters.
    | 254 | +            token_nums = qkv.shape[0]
    | 255 | +            head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
    | 256 | +            q_num_heads = self.num_heads
    | 257 | +            # 3. Allocate the output tensor with the matching dtype.
    | 258 | +            if out_scale > 0.0:
    | 259 | +                if abs(quant_max_bound - 127) < 0.000001:
    | 260 | +                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
    | 261 | +                elif abs(quant_max_bound - 448) < 0.000001:
    | 262 | +                    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
    | 263 | +                else:
    | 264 | +                    raise NotImplementedError("Only quant_max_bound values of 127 and 448 are supported.")
    | 265 | +            else:
    | 266 | +                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
    | 267 | +
    | 268 | +            append_attention_with_output(
    | 269 | +                qkv,
    | 270 | +                forward_meta.caches[2 * layer.layer_id],
    | 271 | +                forward_meta.caches[2 * layer.layer_id + 1],
    | 272 | +                forward_meta.seq_lens_encoder,
    | 273 | +                forward_meta.seq_lens_decoder,
    | 274 | +                forward_meta.seq_lens_this_time,
    | 275 | +                forward_meta.batch_id_per_token,
    | 276 | +                forward_meta.cu_seqlens_q,
    | 277 | +                metadata.block_tables,
    | 278 | +                metadata.encoder_batch_ids,
    | 279 | +                metadata.encoder_tile_ids_per_batch,
    | 280 | +                metadata.encoder_num_blocks,
    | 281 | +                metadata.kv_batch_ids,
    | 282 | +                metadata.kv_tile_ids_per_batch,
    | 283 | +                metadata.kv_num_blocks,
    | 284 | +                forward_meta.decoder_batch_ids,
    | 285 | +                forward_meta.decoder_tile_ids_per_batch,
    | 286 | +                forward_meta.decoder_num_blocks_cpu,
    | 287 | +                forward_meta.max_len_tensor_cpu,
    | 288 | +                metadata.max_len_kv,
    | 289 | +                res,
    | 290 | +                metadata.rotary_embs,
    | 291 | +                metadata.attn_mask,
    | 292 | +                layer.qkv_bias,
    | 293 | +                layer.qkv_scale,
    | 294 | +                getattr(layer, "cache_k_scale", None),
    | 295 | +                getattr(layer, "cache_v_scale", None),
    | 296 | +                getattr(layer, "cache_k_out_scale", None),
    | 297 | +                getattr(layer, "cache_v_out_scale", None),
    | 298 | +                getattr(layer, "cache_k_zp", None),
    | 299 | +                getattr(layer, "cache_v_zp", None),
    | 300 | +                layer.linear_shift,
    | 301 | +                layer.linear_smooth,
    | 302 | +                metadata.kv_signal_data_list[layer.layer_id],
    | 303 | +                getattr(layer, "q_norm_weight", None),
    | 304 | +                getattr(layer, "k_norm_weight", None),
    | 305 | +                getattr(layer, "rms_norm_eps", 1e-6),
    | 306 | +                metadata._fuse_kernel_compute_dtype,
    | 307 | +                getattr(layer, "cache_quant_type_str", "none"),
    | 308 | +                layer.use_neox_rotary_style,
    | 309 | +                self.rope_3d,
    | 310 | +                self.max_seq_len,
    | 311 | +                getattr(layer, "quant_max_bound", 0.0),
    | 312 | +                getattr(layer, "quant_min_bound", 0.0),
    | 313 | +                getattr(layer, "out_scale", -1.0),
    | 314 | +                self.encoder_block_shape_q,
    | 315 | +                self.decoder_block_shape_q,
    | 316 | +                metadata.max_partition_size,
    | 317 | +                metadata.encoder_max_partition_size,
    | 318 | +                self.speculate_max_draft_token_num + 1,
    | 319 | +                self.causal,
    | 320 | +                self.speculative_method is not None,
    | 321 | +            )
    | 322 | +        else:
    | 323 | +            res = append_attention(
    | 324 | +                qkv,
    | 325 | +                forward_meta.caches[2 * layer.layer_id],
    | 326 | +                forward_meta.caches[2 * layer.layer_id + 1],
    | 327 | +                forward_meta.seq_lens_encoder,
    | 328 | +                forward_meta.seq_lens_decoder,
    | 329 | +                forward_meta.seq_lens_this_time,
    | 330 | +                forward_meta.batch_id_per_token,
    | 331 | +                forward_meta.cu_seqlens_q,
    | 332 | +                metadata.block_tables,
    | 333 | +                metadata.encoder_batch_ids,
    | 334 | +                metadata.encoder_tile_ids_per_batch,
    | 335 | +                metadata.encoder_num_blocks,
    | 336 | +                metadata.kv_batch_ids,
    | 337 | +                metadata.kv_tile_ids_per_batch,
    | 338 | +                metadata.kv_num_blocks,
    | 339 | +                forward_meta.decoder_batch_ids,
    | 340 | +                forward_meta.decoder_tile_ids_per_batch,
    | 341 | +                forward_meta.decoder_num_blocks_cpu,
    | 342 | +                forward_meta.max_len_tensor_cpu,
    | 343 | +                metadata.max_len_kv,
    | 344 | +                metadata.rotary_embs,
    | 345 | +                metadata.attn_mask,
    | 346 | +                layer.qkv_bias,
    | 347 | +                layer.qkv_scale,
    | 348 | +                getattr(layer, "cache_k_scale", None),
    | 349 | +                getattr(layer, "cache_v_scale", None),
    | 350 | +                getattr(layer, "cache_k_out_scale", None),
    | 351 | +                getattr(layer, "cache_v_out_scale", None),
    | 352 | +                getattr(layer, "cache_k_zp", None),
    | 353 | +                getattr(layer, "cache_v_zp", None),
    | 354 | +                layer.linear_shift,
    | 355 | +                layer.linear_smooth,
    | 356 | +                metadata.kv_signal_data_list[layer.layer_id],
    | 357 | +                getattr(layer, "q_norm_weight", None),
    | 358 | +                getattr(layer, "k_norm_weight", None),
    | 359 | +                getattr(layer, "rms_norm_eps", 1e-6),
    | 360 | +                metadata._fuse_kernel_compute_dtype,
    | 361 | +                getattr(layer, "cache_quant_type_str", "none"),
    | 362 | +                layer.use_neox_rotary_style,
    | 363 | +                self.rope_3d,
    | 364 | +                self.max_seq_len,
    | 365 | +                getattr(layer, "quant_max_bound", 0.0),
    | 366 | +                getattr(layer, "quant_min_bound", 0.0),
    | 367 | +                getattr(layer, "out_scale", -1.0),
    | 368 | +                self.encoder_block_shape_q,
    | 369 | +                self.decoder_block_shape_q,
    | 370 | +                metadata.max_partition_size,
    | 371 | +                metadata.encoder_max_partition_size,
    | 372 | +                self.speculate_max_draft_token_num + 1,
    | 373 | +                self.causal,
    | 374 | +                self.speculative_method is not None,
    | 375 | +            )
284 | 376 |         return res
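
For reference, a minimal sketch of what the new `use_output` branch does before calling `append_attention_with_output`: it picks an output dtype from the qkv dtype (or from the fused-kernel compute dtype when qkv is int32), doubles the per-head width for the `cache_int4_zp` cache layout, and preallocates the output buffer. The helper name `alloc_attention_output` and its signature are hypothetical, not FastDeploy API; it assumes a Paddle build that supports the `float8_e4m3fn` dtype, and it omits moving the buffer to `qkv.place` as the diff does.

```python
# Hypothetical helper mirroring the output-buffer selection in the diff above.
import paddle


def alloc_attention_output(qkv, compute_type, out_scale, quant_max_bound,
                           num_heads, head_dim, cache_quant_type="none"):
    # Map the qkv dtype (or the fused-kernel compute dtype for int32 qkv)
    # to the floating-point output dtype.
    if qkv.dtype in (paddle.float16, paddle.bfloat16):
        out_dtype = qkv.dtype
    elif qkv.dtype == paddle.int32:
        out_dtype = paddle.bfloat16 if compute_type == "bf16" else paddle.float16
    else:
        raise NotImplementedError(f"unsupported qkv dtype: {qkv.dtype}")

    # The diff widens the per-head dimension for the cache_int4_zp cache layout.
    head_dims = head_dim * 2 if cache_quant_type == "cache_int4_zp" else head_dim
    token_nums = qkv.shape[0]

    # When the layer quantizes its output (out_scale > 0), the buffer is int8
    # (quant_max_bound == 127) or float8_e4m3fn (quant_max_bound == 448).
    if out_scale > 0.0:
        if abs(quant_max_bound - 127) < 1e-6:
            return paddle.empty([token_nums, num_heads * head_dims], dtype="int8")
        if abs(quant_max_bound - 448) < 1e-6:
            return paddle.empty([token_nums, num_heads * head_dims], dtype="float8_e4m3fn")
        raise NotImplementedError(f"unsupported quant_max_bound: {quant_max_bound}")
    return paddle.empty([token_nums, num_heads * head_dims], dtype=out_dtype)
```

Preallocating `res` and passing it into the kernel gives the attention op a caller-owned output buffer instead of one allocated inside the op, which is presumably what makes this path safe to capture when `full_cuda_graph` is enabled (hence the `self.use_output` gate set in `__init__`).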