-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[bugfix] fix qwen3_5 fp8 gpt-bridge #8076
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+30
−16
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
19fcdc9
fix qwen3_5 fp8
Jintao-Huang 4b10204
fix
Jintao-Huang a93c150
fix
Jintao-Huang 84c2387
Merge branch 'main' into fix_qwen3_5_fp8
Jintao-Huang aa24bd3
fix
Jintao-Huang a001ad7
fix
Jintao-Huang 0707639
update
Jintao-Huang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||
| import math | ||
| import megatron.core | ||
| import re | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.nn.functional as F | ||
|
|
@@ -736,11 +737,10 @@ def _set_moe_state( | |
| def _get_hf_grouped(self, is_mtp_layer: bool = False): | ||
| if self.model_type in { | ||
| 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe', | ||
| 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe' | ||
| 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe', | ||
| 'qwen3_5_moe' | ||
| }: | ||
| return False, False | ||
| elif self.model_type == 'qwen3_5_moe' and is_mtp_layer: | ||
| return False, False | ||
| return None, None | ||
|
|
||
| def _get_transpose(self): | ||
|
|
@@ -760,32 +760,46 @@ def _set_mlp_state( | |
| hf_mlp=None, | ||
| is_mtp_layer: bool = False, | ||
| ): | ||
| if to_mcore: | ||
| hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) | ||
| if hf_mlp is None: | ||
| hf_mlp = self._get_hf_mlp(layer_idx) | ||
| is_expert = ep_rank is not None | ||
| num_local_experts = 1 | ||
| hf_grouped = False | ||
| config = self.config | ||
| if is_expert: | ||
| hf_grouped = not hasattr(hf_mlp.experts, '__len__') | ||
| hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0] | ||
| hf_mlp = hf_mlp.experts | ||
| # When converting to_mcore, hf_grouped is determined by default from the hf_state_dict condition. | ||
| # When converting to_hf, it is determined by default from the hf_mlp condition. | ||
| if to_mcore: | ||
| pattern = r'\d+\.down_proj' | ||
| hf_grouped = not any(re.match(pattern, k) is not None for k in hf_state_dict.keys()) | ||
| else: | ||
| hf_grouped = not hasattr(hf_mlp, '__len__') | ||
| if hasattr(hf_mlp, '__len__'): | ||
| hf_mlp = hf_mlp[0] | ||
| num_local_experts = config.num_moe_experts // self.ep_size | ||
| is_gate_up = hasattr(hf_mlp, 'gate_up_proj') | ||
| if to_mcore: | ||
| is_gate_up = any('gate_up_proj' in k for k in hf_state_dict.keys()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| else: | ||
| is_gate_up = hasattr(hf_mlp, 'gate_up_proj') | ||
| # transformers 5.0 compatibility | ||
| if self.is_transformers_5: | ||
| if self.is_transformers_5 and not to_mcore and is_expert: | ||
| _hf_grouped, _is_gate_up = self._get_hf_grouped(is_mtp_layer) | ||
| if _hf_grouped is not None: | ||
| hf_grouped = _hf_grouped | ||
| if _is_gate_up is not None: | ||
| is_gate_up = _is_gate_up | ||
| need_transpose = True | ||
| if self.is_transformers_5: | ||
| if self.is_transformers_5 and hf_grouped: | ||
| need_transpose = self._get_transpose() | ||
|
|
||
| if to_mcore or hf_grouped: | ||
| if hf_grouped and not to_mcore: | ||
| hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) | ||
| else: | ||
| elif not to_mcore: | ||
| hf_state_dict = {} | ||
|
|
||
| # linear_fc1 | ||
| if to_mcore: | ||
| has_scale_inv = any('_scale_inv' in k for k in hf_state_dict.keys()) | ||
|
|
@@ -1623,7 +1637,7 @@ def save_weights(self, | |
| config = self.config | ||
| if config.mtp_num_layers: | ||
| hf_config.num_nextn_predict_layers = config.mtp_num_layers | ||
| if config.fp8 is not None and config.fp8_recipe == 'blockwise' and config.fp8_param_gather: | ||
| if config.fp8 is not None and config.fp8_recipe == 'blockwise' and config.fp8_param: | ||
| if getattr(hf_config, 'quantization_config', None) is None: | ||
| from transformers.utils.quantization_config import FineGrainedFP8Config | ||
| modules_to_not_convert = get_modules_to_not_convert(self.hf_model) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using regex to detect separate experts in the state dict is a robust way to determine
hf_groupedwhen loading weights. This is more reliable than relying on hardcoded model lists. You can simplify the generator expression slightly for better readability.References