Skip to content

Commit 8c4c5ae

Browse files
authored
feat: Add deepseek flops tracker (original NVIDIA-NeMo#1250) (NVIDIA-NeMo#1305)
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com>
1 parent d3a61da commit 8c4c5ae

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

nemo_rl/utils/flops_tracker.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
2626

2727
from nemo_rl.models.policy.utils import sliding_window_overwrite
28-
from nemo_rl.utils.flops_formulas import FLOPSConfig, llama, qwen2, qwen3
28+
from nemo_rl.utils.flops_formulas import FLOPSConfig, deepseekv3, llama, qwen2, qwen3
2929

3030

3131
def get_default_hf_config(model_name: str) -> PretrainedConfig:
@@ -77,6 +77,27 @@ def convert_config_to_flops_config(
7777
attention_heads=config.num_attention_heads,
7878
vocab_size=config.vocab_size,
7979
), llama
80+
elif config.__class__.model_type == "deepseek_v3":
81+
return FLOPSConfig(
82+
gbs=0,
83+
hs=config.hidden_size,
84+
layers=config.num_hidden_layers,
85+
ffn_hs=config.intermediate_size,
86+
attention_heads=config.num_attention_heads,
87+
moe_router_topk=config.num_experts_per_tok,
88+
query_groups=config.num_key_value_heads,
89+
vocab_size=config.vocab_size,
90+
q_lora_rank=config.q_lora_rank,
91+
kv_lora_rank=config.kv_lora_rank,
92+
qk_head_dim=config.qk_nope_head_dim,
93+
qk_pos_emb_head_dim=config.qk_rope_head_dim,
94+
v_head_dim=config.v_head_dim,
95+
moe_layer_freq=1,
96+
moe_shared_expert_intermediate_size=config.moe_intermediate_size,
97+
moe_ffn_hidden_size=config.moe_intermediate_size,
98+
mtp_num_layers=0,
99+
causal_self_attn=True,
100+
), deepseekv3
80101
else:
81102
raise ValueError(f"Unsupported config type: {type(config)}")
82103

tests/unit/utils/test_flops_counter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
("meta-llama/Llama-3.1-405B-Instruct", 128, 8192, 2.65e18),
2929
("Qwen/Qwen3-30B-A3B", 128, 4096, 9.37e15),
3030
("Qwen/Qwen3-235B-A22B", 128, 4096, 6.21e16),
31+
("deepseek-ai/DeepSeek-V3", 1, 4096, 1.023e15),
3132
],
3233
)
3334
def test_flops_counter(model_name, gbs, seqlen, expected_flops):

0 commit comments

Comments (0)