Skip to content

Commit b56af44

Browse files
authored
Fix reuse_grad_buf_for_mxfp8_param_ag for mxfp8 (#14445)
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
1 parent ddd61eb commit b56af44

File tree

5 files changed

+9
-2
lines changed

5 files changed

+9
-2
lines changed

nemo/collections/llm/recipes/precision/mixed_precision.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def bf16_with_mxfp8_mixed() -> run.Config[MegatronMixedPrecision]:
9999
cfg.fp8 = 'hybrid'
100100
cfg.fp8_recipe = "mxfp8"
101101
cfg.fp8_param_gather = True
102+
cfg.reuse_grad_buf_for_mxfp8_param_ag = True
102103
return cfg
103104

104105

@@ -112,6 +113,7 @@ def fp16_with_mxfp8_mixed() -> run.Config[MegatronMixedPrecision]:
112113
cfg.fp8 = 'hybrid'
113114
cfg.fp8_recipe = "mxfp8"
114115
cfg.fp8_param_gather = True
116+
cfg.reuse_grad_buf_for_mxfp8_param_ag = True
115117
return cfg
116118

117119

nemo/lightning/fabric/plugins.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
first_last_layers_bf16: bool = False,
6161
num_layers_at_start_in_bf16: int = 0,
6262
num_layers_at_end_in_bf16: int = 0,
63+
reuse_grad_buf_for_mxfp8_param_ag: bool = False,
6364
fp8_margin: int = 0,
6465
fp8_amax_history_len: int = 1,
6566
fp8_amax_compute_algo: str = "most_recent",
@@ -104,6 +105,7 @@ def __init__(
104105
first_last_layers_bf16=first_last_layers_bf16,
105106
num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
106107
num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
108+
reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
107109
fp8_margin=fp8_margin,
108110
fp8_amax_history_len=fp8_amax_history_len,
109111
fp8_amax_compute_algo=fp8_amax_compute_algo,

nemo/lightning/pytorch/plugins/mixed_precision.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class DtypeConfig:
8686
hysteresis: float = (None,)
8787
num_layers_at_start_in_bf16: int = 0
8888
num_layers_at_end_in_bf16: int = 0
89+
reuse_grad_buf_for_mxfp8_param_ag: bool = False
8990

9091

9192
class MegatronMixedPrecision(Precision):
@@ -122,6 +123,7 @@ def __init__(
122123
fp16_hysteresis: int = 2,
123124
num_layers_at_start_in_bf16: int = 0,
124125
num_layers_at_end_in_bf16: int = 0,
126+
reuse_grad_buf_for_mxfp8_param_ag: bool = False,
125127
) -> None:
126128
if fp8_params is not None:
127129
logging.warning(
@@ -161,6 +163,7 @@ def __init__(
161163
fp8_param_gather=fp8_param_gather,
162164
num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
163165
num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
166+
reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
164167
# fp16 loss scale
165168
loss_scale=fp16_loss_scale,
166169
initial_loss_scale=fp16_initial_loss_scale,

scripts/performance/helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ def set_precision_configs(recipe, compute_dtype: str, fp8_recipe: str | None = N
226226
# Enable reuse_grad_buf_for_mxfp8_param_ag for MXFP8 and disable AG overlap
227227
# because it is not supported with reuse_grad_buf_for_mxfp8_param_ag
228228
if compute_dtype.lower() == "fp8" and fp8_recipe.lower() == "mxfp8":
229-
recipe.trainer.strategy.ddp.reuse_grad_buf_for_mxfp8_param_ag = True
230-
recipe.optim.config.reuse_grad_buf_for_mxfp8_param_ag = True
231229
comm_overlap_callback_idx = get_comm_overlap_callback_idx(recipe.trainer.callbacks)
232230
if comm_overlap_callback_idx is not None:
233231
recipe.trainer.callbacks[comm_overlap_callback_idx].overlap_param_gather = False

tests/collections/llm/recipes/test_mixed_precision.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_bf16_with_mxfp8_mixed_config():
8787
assert config.fp8 == "hybrid"
8888
assert config.fp8_recipe == "mxfp8"
8989
assert config.fp8_param_gather is True
90+
assert config.reuse_grad_buf_for_mxfp8_param_ag is True
9091

9192

9293
def test_fp16_with_mxfp8_mixed_config():
@@ -99,6 +100,7 @@ def test_fp16_with_mxfp8_mixed_config():
99100
assert config.fp8 == "hybrid"
100101
assert config.fp8_recipe == "mxfp8"
101102
assert config.fp8_param_gather is True
103+
assert config.reuse_grad_buf_for_mxfp8_param_ag is True
102104

103105

104106
def test_bf16_with_fp8_current_scaling_mixed_config():

0 commit comments

Comments
 (0)