[https://nvbugs/5498478][fix] Fix eagle3 fp8 kv target model + bf16 draft model + chunked prefill #7805
Changes from all commits: ee2ae57, 378f298, 6653094, 4763096, b81a11a, b58d24a, d3fbc2a, e9e1fa4
```diff
@@ -318,6 +318,7 @@ def __init__(
         self.support_fused_qkv = self.attn.support_fused_qkv()
         self.support_nvfp4_output = self.attn.support_nvfp4_output()
+        self.is_eagle3 = False

         if not config.skip_create_weights_in_init:
             self.create_weights()
```
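For context, the new `is_eagle3` attribute defaults to `False` here and is expected to be switched to `True` on the draft model's attention layers elsewhere in the stack. A minimal sketch of how that could be done (hypothetical helper, not part of this diff):

```python
import torch.nn as nn

def mark_eagle3_attention(draft_model: nn.Module) -> None:
    """Illustrative only: flag every attention module of an Eagle3 draft model
    so that `_attn_impl` knows the layer may share an FP8 KV cache with the
    FP8-quantized target model."""
    for module in draft_model.modules():
        if hasattr(module, "is_eagle3"):
            module.is_eagle3 = True
```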
```diff
@@ -404,6 +405,10 @@ def _attn_impl(
         if mrope_position_deltas is not None:
             mrope_config["mrope_position_deltas"] = mrope_position_deltas

+        # Force FP8 FMHA for a BF16/FP16 model with FP8 KV cache (e.g. eagle3 + FP8 target model + BF16/FP16 draft model)
+        fp8_fmha_for_eagle3 = self.is_eagle3 and not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+        ) and attn_metadata.num_contexts != 0
+
         attn_output = self.attn.forward(
             q,
             k,
```

Review thread on the new comment line:

- Reviewer: this seems too specific (more like a WAR). @yuxianq do you have any insights about this? Thanks.
- Reply: I agree that it is too specific. The purpose of this PR is to add a way to explicitly control whether we use FP8 FMHA outside the attention op. How about adding a […]
- Reply: That makes sense to me. Thanks!
|
```diff
@@ -420,7 +425,8 @@ def _attn_impl(
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
             output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            fp8_fmha_for_eagle3=fp8_fmha_for_eagle3)
         if isinstance(attn_output, tuple):
             assert len(
                 attn_output
```
Review comment on this hunk: we'd better add some comments here to describe the logic.
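For readers who want to exercise the path this PR fixes end to end, below is a rough usage sketch of the scenario in the title (FP8-KV-cache target model + BF16 Eagle3 draft model + chunked prefill). The model paths are placeholders and the field names follow the TensorRT-LLM LLM API from memory, so treat them as assumptions and check the docs for your branch:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

# Paths and exact field names are illustrative assumptions, not taken from this PR.
llm = LLM(
    model="/models/target-model-fp8-kv",                     # target model with FP8 KV cache
    speculative_config=EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="/models/eagle3-draft-bf16",   # BF16 draft model
    ),
    enable_chunked_prefill=True,                             # the combination this PR fixes
)

outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```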