-
Notifications
You must be signed in to change notification settings - Fork 2.1k
feat(dsv3):Runnable N1C8 configs #2525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
hushenwei2000
wants to merge
16
commits into
PaddlePaddle:develop
Choose a base branch
from
hushenwei2000:merge_dsv3_pr
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
4e79040
update
chen2016013 1b6c1f4
fix:(moe config): default using_flex_token
hushenwei2000 0905f8c
doc(comment): fix code comment
hushenwei2000 65bb06d
doc(comment): fix code comment
hushenwei2000 788e712
(1) move all updates into example folder (2) move DSV3_USE_FP8_GEMM D…
hushenwei2000 a733ccb
(1)recover bos download (2) move dsv3_fast_pretrain from env to arg (…
hushenwei2000 49663f1
code format
hushenwei2000 d0f203f
Merge branch 'develop' into merge_dsv3_pr
hushenwei2000 c4446cc
code format
hushenwei2000 0ff74b2
add fa_version in config; fix code
hushenwei2000 bf201bb
remove modeling_fast; move config into config file; format code
hushenwei2000 9164dbf
Merge branch 'develop' into merge_dsv3_pr
hushenwei2000 b881da9
code format
hushenwei2000 35d1b97
add old dataset
hushenwei2000 c99cbb7
use DeepseekV2PretrainingCriterionFast
hushenwei2000 e1b2801
1) replace "dsv3_fast_pretrain" with "reorder_pipeline_priority" 2) m…
hushenwei2000 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import sys | ||
from typing import TYPE_CHECKING | ||
|
||
from paddleformers.utils.lazy_import import _LazyModule | ||
|
||
import_structure = { | ||
"configuration": ["DeepseekV2FastConfig"], | ||
"modeling": [ | ||
"masked_fill", | ||
"DeepseekV2Attention", | ||
"MoEGate", | ||
"FakeGate", | ||
"DeepseekV2ForCausalLM", | ||
"_make_causal_mask", | ||
"is_casual_mask", | ||
"DeepseekV2MoE", | ||
"DeepseekV2MoEFlexToken", | ||
"scaled_dot_product_attention", | ||
"DeepseekV2RotaryEmbedding", | ||
"rotate_half", | ||
"DeepseekV2MTPLayer", | ||
"DeepseekV2RMSNorm", | ||
"DeepseekV2YarnRotaryEmbedding", | ||
"parallel_matmul", | ||
"DeepseekV2PretrainedModel", | ||
"AddAuxiliaryLoss", | ||
"apply_rotary_pos_emb", | ||
"assign_kv_heads", | ||
"DeepseekV2ForSequenceClassification", | ||
"_expand_2d_mask", | ||
"DeepseekV2ModelFast", | ||
"repeat_kv", | ||
"yarn_find_correction_dim", | ||
"yarn_linear_ramp_mask", | ||
"DeepseekV2DynamicNTKScalingRotaryEmbedding", | ||
"DeepseekV2MLP", | ||
"yarn_get_mscale", | ||
"DeepseekV2LMHead", | ||
"DeepseekV2DecoderLayer", | ||
"DeepseekV2PretrainingCriterionFast", | ||
"yarn_find_correction_range", | ||
"get_triangle_upper_mask", | ||
"DeepseekV2LinearScalingRotaryEmbedding", | ||
"set_global_step", | ||
"get_global_step", | ||
], | ||
"modeling_auto": [ | ||
"DeepseekV2LMHeadAuto", | ||
"DeepseekV2ForCausalLMAuto", | ||
"DeepseekV2ModelAuto", | ||
"DeepseekV2PretrainedModelAuto", | ||
], | ||
"modeling_pp": ["DeepseekV2ForCausalLMPipe"], | ||
"mfu_utils": ["DeepSeekProjection"], | ||
"kernel": [ | ||
"act_quant", | ||
"weight_dequant", | ||
"fp8_gemm", | ||
"weight_dequant_kernel", | ||
"act_quant_kernel", | ||
"fp8_gemm_kernel", | ||
], | ||
"tokenizer_fast": ["DeepseekTokenizerFast"], | ||
"fp8_linear": [ | ||
"Linear", | ||
"ColumnParallelLinear", | ||
"RowParallelLinear", | ||
"ColumnSequenceParallelLinear", | ||
"RowSequenceParallelLinear", | ||
], | ||
} | ||
|
||
if TYPE_CHECKING: | ||
from .configuration import * | ||
from .modeling import * | ||
from .modeling_auto import * | ||
from .modeling_pp import * | ||
from .tokenizer_fast import * | ||
else: | ||
sys.modules[__name__] = _LazyModule( | ||
__name__, | ||
globals()["__file__"], | ||
import_structure, | ||
module_spec=__spec__, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
{ | ||
"architectures": [ | ||
"DeepseekV2ForCausalLM" | ||
], | ||
"attention_bias": false, | ||
"attention_dropout": 0.0, | ||
"auto_map": { | ||
"AutoConfig": "DeepseekV2FastConfig", | ||
"AutoModel": "DeepseekV2ModelFast", | ||
"AutoModelForCausalLM": "DeepseekV2ForCausalLM" | ||
}, | ||
"aux_loss_alpha": 0.001, | ||
"bos_token_id": 0, | ||
"eos_token_id": 1, | ||
"ep_size": 1, | ||
"first_k_dense_replace": 3, | ||
"hidden_act": "silu", | ||
"hidden_size": 7168, | ||
"initializer_range": 0.02, | ||
"intermediate_size": 18432, | ||
"kv_lora_rank": 512, | ||
"max_position_embeddings": 163840, | ||
"model_type": "deepseek_v3", | ||
"moe_intermediate_size": 2048, | ||
"moe_layer_freq": 1, | ||
"n_group": 8, | ||
"n_routed_experts": 8, | ||
"n_shared_experts": 1, | ||
"norm_topk_prob": true, | ||
"num_attention_heads": 128, | ||
"num_experts_per_tok": 8, | ||
"num_hidden_layers": 15, | ||
"num_key_value_heads": 128, | ||
"num_nextn_predict_layers": 1, | ||
"pretraining_tp": 1, | ||
"q_lora_rank": 1536, | ||
"qk_nope_head_dim": 128, | ||
"qk_rope_head_dim": 64, | ||
"rms_norm_eps": 1e-06, | ||
"rope_scaling": { | ||
"beta_fast": 32, | ||
"beta_slow": 1, | ||
"factor": 40, | ||
"mscale": 1.0, | ||
"mscale_all_dim": 1.0, | ||
"original_max_position_embeddings": 4096, | ||
"type": "yarn" | ||
}, | ||
"rope_theta": 10000, | ||
"routed_scaling_factor": 2.5, | ||
"scoring_func": "sigmoid", | ||
"seq_aux": true, | ||
"tie_word_embeddings": false, | ||
"topk_group": 4, | ||
"topk_method": "noaux_tc", | ||
"dtype": "bfloat16", | ||
"transformers_version": "4.33.1", | ||
"use_cache": true, | ||
"v_head_dim": 128, | ||
"vocab_size": 129280, | ||
"using_flex_token": true, | ||
"using_fake_gate": true, | ||
"use_fused_rms_norm": true, | ||
"fuse_attention_ffn": true, | ||
"use_fused_rope": true, | ||
"token_drop_steps": 0, | ||
"recompute_fwd_gate_up": true, | ||
"adaptive_remained_O1_recompute_ratio": 0.3, | ||
"using_post_norm_recompute": true, | ||
"is_split_group_gemm": false, | ||
"use_dualpipev": true, | ||
"send_mtp_embed": true, | ||
"offline_quant_expert_weight": false, | ||
"clear_origin_weight_when_offline_quant": false, | ||
"dsv3_use_fp8_gemm": true, | ||
"dsv3_use_atten_recompute": true, | ||
"use_ds_gemm": false, | ||
"dsv3_use_fp8_dispatch": true, | ||
"fa_version": 3 | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放在这个目录下https://github.com/PaddlePaddle/PaddleFormers/tree/develop/examples/experiments/deepseek_v3_pretrain