[https://nvbugs/5498478][fix] Fix eagle3 fp8 kv target model + bf16 draft model + chunked prefill #7805
First changed file:

```diff
@@ -1211,6 +1211,7 @@ def forward(
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
         chunked_prefill_buffer_batch_size: int = 1,
+        fp8_fmha_for_eagle3: bool = False,
         **kwargs,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         assert isinstance(
@@ -1293,6 +1294,10 @@ def forward(
         if use_nvfp4_output:
             # Use UINT8 as the container dtype for NVFP4.
             out_dtype = torch.uint8
+        # elif fp8_fmha_for_eagle3:
+        elif self.has_fp8_kv_cache and not self.has_fp8_qdq and out_scale is not None:
+            # Force to use FP8 FMHA for (eagle3 + FP8 target model + BF16/FP16 draft model) in draft layers
+            out_dtype = torch.float8_e4m3fn
         elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise
               or self.has_fp8_rowwise
               or self.has_w4a8_nvfp4_fp8) and (self.has_fp8_kv_cache
```
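The new branch above can be read as a small predicate: if the layer holds an FP8 KV cache but its weights are not FP8-quantized (the eagle3 BF16/FP16 draft case) and an output scale is available, the FMHA output dtype is forced to FP8. Below is a minimal sketch of that rule only; `pick_out_dtype` and `default_dtype` are hypothetical names introduced here for illustration and the other quantization branches are omitted.

```python
import torch
from typing import Optional


def pick_out_dtype(use_nvfp4_output: bool,
                   has_fp8_kv_cache: bool,
                   has_fp8_qdq: bool,
                   out_scale: Optional[torch.Tensor],
                   default_dtype: torch.dtype) -> torch.dtype:
    """Illustrative only: mirrors the out_dtype branch added in the hunk above."""
    if use_nvfp4_output:
        # NVFP4 output is packed into a UINT8 container.
        return torch.uint8
    if has_fp8_kv_cache and not has_fp8_qdq and out_scale is not None:
        # BF16/FP16 weights with an FP8 KV cache (eagle3 draft layers): force FP8 FMHA output.
        return torch.float8_e4m3fn
    return default_dtype


# Eagle3 draft layer: FP8 KV cache, no FP8 weight quantization, out_scale provided.
scale = torch.tensor([1.0], dtype=torch.float32)
print(pick_out_dtype(False, True, False, scale, torch.bfloat16))  # torch.float8_e4m3fn
```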
Second changed file:

```diff
@@ -404,6 +404,14 @@ def _attn_impl(
         if mrope_position_deltas is not None:
             mrope_config["mrope_position_deltas"] = mrope_position_deltas

+        # Be forced to use FP8 FMHA for BF16/FP16 model with FP8 KV cache (e.g. eagle3 + FP8 target model + BF16/FP16 draft model)
+        forced_to_fp8_fmha = not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+        ) and attn_metadata.num_contexts != 0
+        if forced_to_fp8_fmha:
+            out_scale = torch.tensor([1.0],
+                                     dtype=torch.float32,
+                                     device=q.device)
+
         attn_output = self.attn.forward(
             q,
             k,
```

Review thread on the `forced_to_fp8_fmha` block:

- "this seems too specific (more like a WAR). @yuxianq do you have any insights about this? thanks."
- "I agree that it is too specific. The purpose of this PR is to add a way to explicitly control whether we use fp8 fmha outside attention op. How about add a"
- "that makes sense to me. Thanks!"
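The thread above leans toward controlling FP8 FMHA explicitly from outside the attention op rather than re-deriving it from quantization flags inside `_attn_impl`. The following is a hypothetical sketch of that direction, reusing the `fp8_fmha_for_eagle3` parameter name from the first file's diff; the stub op and call site are assumptions for illustration, not the actual TensorRT-LLM wiring.

```python
from typing import Optional

import torch


def attn_forward(q: torch.Tensor,
                 out_scale: Optional[torch.Tensor] = None,
                 fp8_fmha_for_eagle3: bool = False) -> torch.Tensor:
    """Stub attention op: only shows how an explicit flag would pick the output dtype."""
    out_dtype = torch.float8_e4m3fn if fp8_fmha_for_eagle3 else q.dtype
    # A real implementation would run FMHA here; the stub just returns zeros.
    return torch.zeros_like(q).to(out_dtype)


# The caller (e.g. the draft-model attention module) decides, instead of the op guessing:
q = torch.randn(2, 8, dtype=torch.bfloat16)
out = attn_forward(q, out_scale=torch.tensor([1.0]), fp8_fmha_for_eagle3=True)
print(out.dtype)  # torch.float8_e4m3fn
```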
|
|
The return path in the same file is updated so the forced FP8 output is cast back to the query dtype:

```diff
@@ -425,7 +433,12 @@ def _attn_impl(
             assert len(
                 attn_output
             ) == 2, "attn_output should be a tuple of (output, output_sf)"
-            return attn_output[0], attn_output[1]
+            if forced_to_fp8_fmha:
+                return attn_output[0].to(q.dtype), attn_output[1]
+            else:
+                return attn_output[0], attn_output[1]
+        if forced_to_fp8_fmha:
+            return attn_output.to(q.dtype), None
         return attn_output, None

     def forward_impl(
```
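When FP8 FMHA was forced for a BF16/FP16 layer, the output is converted back to the query dtype so the rest of the draft model keeps running in BF16/FP16; since `out_scale` was pinned to 1.0, the cast presumably needs no extra rescaling. A minimal sketch of that return-path behavior (the tensors below are stand-ins, not the real FMHA output):

```python
import torch

forced_to_fp8_fmha = True
q = torch.randn(4, 16, dtype=torch.bfloat16)
# Stand-in for the FP8 FMHA result returned by the attention op.
attn_output = torch.zeros(4, 16, dtype=torch.float8_e4m3fn)

result = attn_output.to(q.dtype) if forced_to_fp8_fmha else attn_output
print(result.dtype)  # torch.bfloat16
```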