
Commit 752edac

Split out updates to model config into its own function and call it during setup (#12663)
* move ddp config updates to setup
* split into own function
* names
* add copyright for example
* address comments
* remove changes to get_model_from_config
* update model config sync funcs as part of get_model_from_config
* move back to setup, include grad scale func from optimizer
* lint
* undo
* right check
* keep order
* keep comment
* lints
* remove arg
* updates

Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent ea484e3 commit 752edac

File tree

4 files changed, +99 -60 lines changed


nemo/tron/examples/lingua-1b_dclm.py

Lines changed: 14 additions & 0 deletions

@@ -1,3 +1,17 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from functools import partial

nemo/tron/model.py

Lines changed: 4 additions & 5 deletions

@@ -27,10 +27,6 @@
 from nemo.collections.llm.t5.model.t5 import T5Config


-def _get_model_type(model_config: GPTConfig | T5Config) -> ModelType:
-    return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder
-
-
 def get_model_from_config(
     model_config: GPTConfig | T5Config,
     ddp_config: DistributedDataParallelConfig,
@@ -151,5 +147,8 @@ def get_model_from_config(
     if data_parallel_random_init:
         for model_module in model:
             model_module.broadcast_params()
-
     return model
+
+
+def _get_model_type(model_config: GPTConfig | T5Config) -> ModelType:
+    return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder
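The only change in this file is moving _get_model_type below get_model_from_config; its dispatch is unchanged: T5 configs map to encoder_and_decoder, everything else to encoder_or_decoder. Below is a minimal sketch of that dispatch, using empty stub classes in place of the real GPTConfig/T5Config (which take many constructor arguments); only megatron.core's ModelType enum is assumed available.

from megatron.core.enums import ModelType


class GPTConfig:  # stub standing in for nemo.collections.llm.gpt.model.base.GPTConfig
    pass


class T5Config:  # stub standing in for nemo.collections.llm.t5.model.t5.T5Config
    pass


def _get_model_type(model_config) -> ModelType:
    # Same decision as the relocated helper: T5 models span both encoder and decoder.
    return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder


assert _get_model_type(T5Config()) is ModelType.encoder_and_decoder
assert _get_model_type(GPTConfig()) is ModelType.encoder_or_decoder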

nemo/tron/setup.py

Lines changed: 80 additions & 32 deletions

@@ -13,17 +13,24 @@
 # limitations under the License.

 import time
-from typing import Any, NamedTuple, Optional
+from typing import Any, Dict, NamedTuple, Optional

 import torch
+from megatron.core.distributed import (
+    DistributedDataParallel,
+    DistributedDataParallelConfig,
+    finalize_model_grads,
+)
 from megatron.core.optimizer import MegatronOptimizer
 from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.core.rerun_state_machine import RerunDataIterator
 from megatron.core.transformer import MegatronModule

+from nemo.collections.llm.gpt.model.base import GPTConfig
+from nemo.collections.llm.t5.model.t5 import T5Config
 from nemo.tron import fault_tolerance
 from nemo.tron.checkpointing import checkpoint_exists, load_checkpoint
-from nemo.tron.config import ConfigContainer
+from nemo.tron.config import CheckpointConfig, ConfigContainer
 from nemo.tron.data.dataset import setup_data_iterators
 from nemo.tron.init import initialize_megatron, set_jit_fusion_options
 from nemo.tron.model import get_model_from_config
@@ -95,35 +102,7 @@ def setup(
     barrier_and_log("after megatron is initialized")

     # Context used for persisting some state between checkpoint saves.
-    if cfg.checkpoint_config.non_persistent_ckpt_type == "local":
-        if HAVE_RESIL:
-            from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
-                LocalCheckpointManager,
-            )
-            from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
-                CliqueReplicationStrategy,
-            )
-        else:
-            raise RuntimeError(
-                "The 'nvidia_resiliency_ext' module is required for local "
-                "checkpointing but was not found. Please ensure it is installed."
-            )
-        if cfg.checkpoint_config.replication:
-            repl_strategy = CliqueReplicationStrategy.from_replication_params(
-                cfg.checkpoint_config.replication_jump,
-                cfg.checkpoint_config.replication_factor,
-            )
-        else:
-            repl_strategy = None
-
-        checkpointing_context = {
-            "local_checkpoint_manager": LocalCheckpointManager(
-                cfg.checkpoint_config.non_persistent_local_ckpt_dir,
-                repl_strategy=repl_strategy,
-            )
-        }
-    else:
-        checkpointing_context = {}
+    checkpointing_context = _init_checkpointing_context(cfg.checkpoint_config)

     # Tokenizer
     timers("tokenizer-setup", log_level=0).start(barrier=True)
@@ -146,9 +125,16 @@ def setup(
         overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step,
         data_parallel_random_init=cfg.rng_config.data_parallel_random_init,
     )
+    cfg.model_config.timers = timers
     cfg.optimizer_config.timers = timers
     optimizer, scheduler = setup_optimizer(cfg, model)
-
+    _update_model_config_funcs(
+        model,
+        cfg.model_config,
+        cfg.ddp_config,
+        optimizer,
+        align_grad_reduce=cfg.dist_config.align_grad_reduce,
+    )
     timers("model-and-optimizer-setup").stop()
     barrier_and_log("after model, optimizer, and learning rate scheduler are built")

@@ -199,3 +185,65 @@ def setup(
         test_data_iterator,
         checkpointing_context,
     )
+
+def _init_checkpointing_context(checkpoint_config: CheckpointConfig) -> Dict[str, Any]:
+    # Context used for persisting some state between checkpoint saves.
+    if checkpoint_config.non_persistent_ckpt_type != "local":
+        return {}
+
+    if not HAVE_RESIL:
+        raise RuntimeError(
+            "The 'nvidia_resiliency_ext' module is required for local "
+            "checkpointing but was not found. Please ensure it is installed."
+        )
+
+    from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
+        LocalCheckpointManager,
+    )
+    from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
+        CliqueReplicationStrategy,
+    )
+    if checkpoint_config.replication:
+        repl_strategy = CliqueReplicationStrategy.from_replication_params(
+            checkpoint_config.replication_jump,
+            checkpoint_config.replication_factor,
+        )
+    else:
+        repl_strategy = None
+
+    checkpointing_context = {
+        "local_checkpoint_manager": LocalCheckpointManager(
+            checkpoint_config.non_persistent_local_ckpt_dir,
+            repl_strategy=repl_strategy,
+        )
+    }
+    return checkpointing_context
+
+
+def _update_model_config_funcs(
+    model: MegatronModule,
+    model_config: GPTConfig | T5Config,
+    ddp_config: DistributedDataParallelConfig,
+    optimizer: MegatronOptimizer,
+    *,
+    align_grad_reduce: bool = True
+) -> None:
+    """Update model config sync funcs based on initialized model."""
+    if isinstance(model[0], DistributedDataParallel) and ddp_config.overlap_grad_reduce:
+        assert model_config.no_sync_func is None, (
+            "When overlap_grad_reduce is True, config.no_sync_func must be None; "
+            "a custom no_sync_func is not supported when overlapping grad-reduce"
+        )
+        model_config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
+        if len(model) == 1:
+            model_config.no_sync_func = model_config.no_sync_func[0]
+        if align_grad_reduce:
+            model_config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
+            if len(model) == 1:
+                model_config.grad_sync_func = model_config.grad_sync_func[0]
+    if ddp_config.overlap_param_gather and ddp_config.align_param_gather:
+        model_config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
+        if len(model) == 1:
+            model_config.param_sync_func = model_config.param_sync_func[0]
+    model_config.finalize_model_grads_func = finalize_model_grads
+    model_config.grad_scale_func = optimizer.scale_loss
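
The new _update_model_config_funcs attaches no_sync_func, grad_sync_func, and param_sync_func either as a single callable (one model chunk) or as a list of callables (multiple chunks, e.g. interleaved pipeline stages). The snippet below is a self-contained illustration of that collapsing rule only, using stand-in chunk objects rather than real Megatron DDP modules.

class _Chunk:
    """Stand-in for a Megatron DDP model chunk exposing start_grad_sync()."""

    def start_grad_sync(self):
        print("grad sync started")


def collapse(funcs):
    # Mirrors the `if len(model) == 1` handling in _update_model_config_funcs above.
    return funcs[0] if len(funcs) == 1 else funcs


single_chunk = [_Chunk()]
interleaved_chunks = [_Chunk(), _Chunk()]

print(callable(collapse([c.start_grad_sync for c in single_chunk])))                 # True: one chunk -> bare callable
print(isinstance(collapse([c.start_grad_sync for c in interleaved_chunks]), list))   # True: several chunks -> list of callables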

nemo/tron/train.py

Lines changed: 1 addition & 23 deletions

@@ -23,7 +23,6 @@
 import torch
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as DDP
-from megatron.core.distributed import finalize_model_grads
 from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
     get_current_running_global_batch_size,
@@ -84,27 +83,6 @@ def train(

     num_floating_point_operations_so_far = global_state.train_state.floating_point_operations_so_far
     num_floating_point_operations_since_last_log_event = 0.0
-    model_config.grad_scale_func = optimizer.scale_loss
-    model_config.timers = timers
-
-    ddp_config = config.ddp_config
-    if isinstance(model[0], DDP) and ddp_config.overlap_grad_reduce:
-        assert model_config.no_sync_func is None, (
-            "When overlap_grad_reduce is True, config.no_sync_func must be None; "
-            "a custom no_sync_func is not supported when overlapping grad-reduce"
-        )
-        model_config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
-        if len(model) == 1:
-            model_config.no_sync_func = model_config.no_sync_func[0]
-        if config.dist_config.align_grad_reduce:
-            model_config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
-            if len(model) == 1:
-                model_config.grad_sync_func = model_config.grad_sync_func[0]
-    if ddp_config.overlap_param_gather and ddp_config.align_param_gather:
-        model_config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
-        if len(model) == 1:
-            model_config.param_sync_func = model_config.param_sync_func[0]
-    model_config.finalize_model_grads_func = finalize_model_grads

     timers("interval-time", log_level=0).start(barrier=True)
     report_memory_flag = True
@@ -160,7 +138,7 @@ def train(

     start_iteration = global_state.train_state.step
     should_toggle_forward_pre_hook = (
-        config.optimizer_config.use_distributed_optimizer and ddp_config.overlap_param_gather
+        config.optimizer_config.use_distributed_optimizer and config.ddp_config.overlap_param_gather
     )
     # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
     # or random initialization don't propagate to all ranks in first all-gather (which is a
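
With this wiring removed, train() assumes setup() has already populated these attributes on the model config. A hypothetical sanity check (not part of this commit) that a caller could run before entering the training loop; the attribute names come from _update_model_config_funcs and the setup() changes above.

# Hypothetical helper, not part of the NeMo codebase.
REQUIRED_CONFIG_ATTRS = ("grad_scale_func", "finalize_model_grads_func", "timers")


def check_model_config_wired(model_config) -> None:
    # Raise early if setup-time wiring is missing, instead of failing mid-iteration.
    missing = [name for name in REQUIRED_CONFIG_ATTRS if getattr(model_config, name, None) is None]
    if missing:
        raise RuntimeError(f"model_config is missing setup-time wiring: {missing}")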
