
Commit e276910

DSV3 NVFP4 recipe on GB300 (#2076)

Authored by dingqingy-nvko3n1g and oliver könig
Signed-off-by: Dingqing Yang <[email protected]>
Co-authored-by: oliver könig <[email protected]>

1 parent 8a937f6 commit e276910

File tree

5 files changed: +23 -4 lines changed


scripts/performance/configs/deepseek/deepseek_llm_pretrain.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def deepseek_v3_pretrain_config_gb300(
         pipeline_model_parallel_size=base_cfg.pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size=base_cfg.virtual_pipeline_model_parallel_size,
         moe_flex_dispatcher_backend=base_cfg.moe_flex_dispatcher_backend,
-        layout=None,
+        layout=base_cfg.pp_layout,
     )
     set_deepseek_v3_common_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
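
Note: the change above stops hard-coding layout=None and forwards the base config's pp_layout instead. A minimal sketch of that plumbing, with simplified, hypothetical types (BaseCfg and build_model_cfg are stand-ins, not the repo's real API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseCfg:
    # Stand-in for the fields WorkloadBaseConfig carries in this repo.
    pipeline_model_parallel_size: int = 2
    virtual_pipeline_model_parallel_size: Optional[int] = 8
    pp_layout: Optional[str] = None  # e.g. "Et*4|(t*4|)*14tmL"


def build_model_cfg(base_cfg: BaseCfg) -> dict:
    # Stand-in for the model-config constructor call in the diff:
    # the layout argument now comes from the base config.
    return dict(
        pipeline_model_parallel_size=base_cfg.pipeline_model_parallel_size,
        virtual_pipeline_model_parallel_size=base_cfg.virtual_pipeline_model_parallel_size,
        layout=base_cfg.pp_layout,  # was layout=None before this commit
    )


cfg = build_model_cfg(BaseCfg(pp_layout="Et*4|(t*4|)*14tmL"))
assert cfg["layout"] == "Et*4|(t*4|)*14tmL"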

scripts/performance/configs/deepseek/deepseek_workload_base_configs.py

Lines changed: 14 additions & 2 deletions
@@ -54,7 +54,16 @@
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_V1 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V1
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_V1 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V1
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_V1 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V1
-DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V1 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V1
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V1 = replace(
+    BASE_DEEPSEEK_V3_CONFIG,
+    micro_batch_size=2,
+    pipeline_model_parallel_size=2,
+    virtual_pipeline_model_parallel_size=8,
+    pp_layout="Et*4|(t*4|)*14tmL",
+    expert_model_parallel_size=32,
+    cuda_graph_scope=[],
+    recompute_modules=["mla_up_proj"],
+)
 
 
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_V1 = replace(

@@ -133,7 +142,10 @@
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_BF16_V2 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V2
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_CS_V2 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V2
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_FP8_MX_V2 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V2
-DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V2 = DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_V2
+DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V2 = replace(
+    DEEPSEEK_V3_PRETRAIN_CONFIG_GB300_NVFP4_V1,
+    global_batch_size=4096,
+)
 
 
 DEEPSEEK_V3_PRETRAIN_CONFIG_GB200_V2 = replace(
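
Note: the pp_layout string follows Megatron-style pipeline-layout notation (my reading, stated as an assumption): 'E' is the embedding, 't' a transformer layer, 'm' the MTP layer, 'L' the loss, '|' a virtual-stage boundary, and '*n' / '(...)*n' repetition. Under that reading the arithmetic lines up with DeepSeek-V3's 61 decoder layers and with pp=2 x vp=8 virtual stages; the V2 config then inherits everything from V1 and only shrinks the global batch.

# "Et*4|(t*4|)*14tmL", decoded under the assumed grammar:
#   stage 0:       embedding + 4 transformer layers  ("Et*4|")
#   stages 1..14:  4 transformer layers each         ("(t*4|)*14")
#   stage 15:      1 transformer layer + MTP + loss  ("tmL")
layers = 4 + 14 * 4 + 1
stages = 1 + 14 + 1
assert layers == 61      # DeepSeek-V3 decoder depth
assert stages == 2 * 8   # pipeline_model_parallel_size * virtual_pipeline_model_parallel_size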

scripts/performance/perf_plugins.py

Lines changed: 4 additions & 0 deletions
@@ -446,6 +446,10 @@ def setup(self, task: Union["run.Partial", "run.Script"], executor: "run.Executo
             self.train_task,
         )
 
+        # Set NVFP4-specific environment variables
+        if self.compute_dtype == "nvfp4":
+            executor.env_vars["NVTE_USE_FAST_MATH"] = "1"
+
 
 @dataclass
 class PyTorchProfilerPluginScriptArgs:
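
Note: a self-contained sketch of the pattern this plugin hook uses: dtype-specific environment variables are injected into the job's environment before launch, so NVTE_USE_FAST_MATH only reaches NVFP4 runs (apply_dtype_env is a hypothetical helper, not the repo's API):

import os


def apply_dtype_env(env_vars: dict, compute_dtype: str) -> None:
    # Mirror of the hook above: only NVFP4 jobs get the fast-math flag.
    if compute_dtype == "nvfp4":
        env_vars["NVTE_USE_FAST_MATH"] = "1"


env = dict(os.environ)
apply_dtype_env(env, "nvfp4")
assert env["NVTE_USE_FAST_MATH"] == "1"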

scripts/performance/utils/overrides.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def set_post_overrides(
     dp = int(num_gpus / (tp * pp * cp))
     logger.info(f"DP: {dp}; TP: {tp}; PP: {pp}; CP: {cp}; VP: {vp}")
     ## NOTE: overlap_param_gather_with_optimizer_step causes NaN grad norm for fp8_mx. Disabling it until the issue is resolved.
-    if dp > 1 and pp > 1 and vp > 1 and compute_dtype != "fp8_mx":
+    if dp > 1 and pp > 1 and vp > 1 and compute_dtype not in ("fp8_mx", "nvfp4"):
         recipe.optimizer.overlap_param_gather_with_optimizer_step = True
         if hasattr(recipe, "comm_overlap") and isinstance(recipe.comm_overlap, CommOverlapConfig):
             recipe.comm_overlap.overlap_param_gather_with_optimizer_step = True
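
Note: the gating condition reads cleanly as a membership test against the dtypes known to break the overlap; a sketch with hypothetical names (should_overlap_param_gather is not the repo's function):

# Dtypes for which overlap_param_gather_with_optimizer_step is disabled:
# fp8_mx (NaN grad norm, per the NOTE above) and now nvfp4.
UNSUPPORTED_OVERLAP_DTYPES = ("fp8_mx", "nvfp4")


def should_overlap_param_gather(dp: int, pp: int, vp: int, compute_dtype: str) -> bool:
    return dp > 1 and pp > 1 and vp > 1 and compute_dtype not in UNSUPPORTED_OVERLAP_DTYPES


assert should_overlap_param_gather(2, 2, 2, "bf16")
assert not should_overlap_param_gather(2, 2, 2, "nvfp4")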

scripts/performance/utils/utils.py

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ class WorkloadBaseConfig:
     moe_a2a_overlap: Optional[bool] = False
     peft: Optional[str] = None
 
+    # Pipeline parallelism layout
+    pp_layout: Optional[str] = None
+
     @property
     def sequence_parallel(self) -> bool:
         """Get the sequence parallel flag."""
