@@ -132,7 +132,7 @@ def _bind_attention(self):
                     LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self
                 )
             else:
-                raise Exception("fp8 kvcache only support fa3 and flashinfer backend")
+                raise Exception("calibration fp8 kvcache only support fa3 and flashinfer backend")
         elif "triton_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
                 LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self
@@ -333,6 +333,13 @@ def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionSt
     def _context_attention_flashattention_fp8(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
+        q, q_scale = q_per_head_fp8_quant(
+            q.view(q.shape[0], self.tp_k_head_num_, -1),
+            infer_state.b_seq_len,
+            infer_state.cu_seqlens_q,
+            infer_state.q_scale,
+            infer_state.batch_ids,
+        )
         cache_k = (
             (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :])
             .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_)
@@ -347,43 +354,21 @@ def _context_attention_flashattention_fp8(
             .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
             .view(torch.float8_e4m3fn)
         )
-        q, q_scale = q_per_head_fp8_quant(
-            q.view(q.shape[0], self.tp_k_head_num_, -1),
-            infer_state.b_seq_len,
-            infer_state.cu_seqlens_q,
-        )
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=infer_state.q_max_seq_len,
-            softmax_scale=sm_scale,
             causal=True,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale,
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
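
Note on the prefill change: q_per_head_fp8_quant now runs before the KV-cache views are built, and its scales are passed straight to flash_attn_with_kvcache as q_descale; the new infer_state.q_scale and infer_state.batch_ids arguments appear to be preallocated buffers for that kernel (an assumption, not confirmed by this hunk). As a rough reference only, a pure-PyTorch sketch of per-request, per-KV-head FP8 quantization, which is how I read the grouping here, could look like the hypothetical helper below (not the actual Triton kernel behind q_per_head_fp8_quant):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_head_fp8_quant_sketch(q: torch.Tensor, cu_seqlens_q: torch.Tensor):
    # q: [total_tokens, kv_head_num, gqa_group_size * head_dim] in fp16/bf16.
    # cu_seqlens_q: [batch + 1] cumulative token offsets, one slice per request.
    # Returns fp8 q plus one float32 scale per (request, kv_head), matching the
    # [batch, kv_heads] shape flash-attn expects for q_descale.
    batch = cu_seqlens_q.numel() - 1
    q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    scales = torch.empty((batch, q.shape[1]), device=q.device, dtype=torch.float32)
    for b in range(batch):
        s, e = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
        amax = q[s:e].abs().amax(dim=(0, 2)).float().clamp(min=1e-12)  # per kv_head
        scales[b] = amax / FP8_MAX
        q_fp8[s:e] = (
            (q[s:e].float() / scales[b][None, :, None]).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        )
    return q_fp8, scales
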
@@ -867,38 +852,21 @@ def _token_decode_attention_flashattention_fp8(
             .view(torch.float8_e4m3fn)
         )
         q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True)
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale.view(q.shape[0], self.tp_k_head_num_)
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
-            softmax_scale=sm_scale,
             causal=False,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_),
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
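
Both the prefill and decode paths also drop the hand-rolled sm_scale (flash_attn_with_kvcache falls back to its default softmax scale of 1/sqrt(head_dim), the same value the removed code computed) and read per-layer k/v descale factors from infer_state instead of rebuilding them from offline_fp8_quant_manager on every call. In the decode path, q_scale.view(infer_state.batch_size, self.tp_k_head_num_) works because each decode step carries exactly one query token per request, so scaled_fp8_quant's per-token scales map one-to-one onto the [batch, kv_heads] layout. As a sanity check on why descale factors alone are enough, the standalone snippet below (made-up shapes, not project code) shows the fp8 dot product rescaled by q_descale * k_descale matching the fp16 attention logit:

import torch

torch.manual_seed(0)
head_dim = 128
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quant(x: torch.Tensor):
    # Per-tensor dynamic quantization: one scale derived from the absolute max.
    scale = x.abs().max().float() / FP8_MAX
    return (x.float() / scale).to(torch.float8_e4m3fn), scale

q = torch.randn(head_dim, dtype=torch.float16)
k = torch.randn(head_dim, dtype=torch.float16)
q8, q_descale = quant(q)
k8, k_descale = quant(k)

ref = (q.float() @ k.float()) / head_dim ** 0.5
fp8 = (q8.float() @ k8.float()) * q_descale * k_descale / head_dim ** 0.5
print(f"fp16 logit {ref.item():.4f} vs fp8+descale {fp8.item():.4f}")  # agree to ~1e-2
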