
Commit 1645e61

feat: nano v3 configs and FSDP fix (#964)
Signed-off-by: adil-a <[email protected]>
1 parent 1efd3e8 commit 1645e61

9 files changed: +244 −7 lines changed
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# To run this recipe, please use the following command:
# torchrun --nproc-per-node=8 recipes/llm_finetune/finetune.py --config recipes/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml

step_scheduler:
  global_batch_size: 16
  local_batch_size: 1
  ckpt_every_steps: 1000
  val_every_steps: 1000  # will run every x number of gradient steps
  max_steps: 100

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# torch.compile configuration
compile:
  enabled: false
  mode: "default"  # Options: "default", "reduce-overhead", "max-autotune"
  fullgraph: false
  dynamic: true  # Set to false for better performance with fixed shapes
  backend: null  # Use default backend (inductor)

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  dp_replicate_size: 1  # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 1
  cp_size: 1
  sequence_parallel: false
  defer_fsdp_grad_sync: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

packed_sequence:
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: True

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: validation
  limit_dataset_samples: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

# wandb:
#   project: <your_wandb_project>
#   entity: <your_wandb_entity>
#   name: <your_wandb_exp_name>
#   save_dir: <your_wandb_save_dir>
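
Note on the step_scheduler values above: with global_batch_size 16, local_batch_size 1, and the 8-rank torchrun launch from the header comment (tp_size and cp_size are 1, so all ranks are data-parallel), the usual batching convention implies 2 gradient-accumulation micro-batches per optimizer step, which is exactly the case the defer_fsdp_grad_sync flag introduced in this commit targets. A minimal sketch of that arithmetic (an illustration of the convention, not code from the recipe):

    # Sketch: micro-batches per optimizer step, assuming the usual convention
    # global = local * dp_world_size * grad_accum. Illustrative only.
    def grad_accum_steps(global_batch_size: int, local_batch_size: int, dp_world_size: int) -> int:
        assert global_batch_size % (local_batch_size * dp_world_size) == 0
        return global_batch_size // (local_batch_size * dp_world_size)

    print(grad_accum_steps(16, 1, 8))  # 2: only the second micro-batch needs a gradient sync when deferral is on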
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# To run this recipe, please use the following command:
# torchrun --nproc-per-node=8 recipes/llm_finetune/finetune.py --config recipes/llm_finetune/nemotron/llama3_3_nemotron_super_49B_squad_peft.yaml

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 1000
  val_every_steps: 1000  # will run every x number of gradient steps
  max_steps: 100

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# torch.compile configuration
compile:
  enabled: false
  mode: "default"  # Options: "default", "reduce-overhead", "max-autotune"
  fullgraph: false
  dynamic: true  # Set to false for better performance with fixed shapes
  backend: null  # Use default backend (inductor)

peft:
  _target_: nemo_automodel.components._peft.lora.PeftConfig
  match_all_linear: True
  dim: 8
  alpha: 32
  use_triton: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  dp_replicate_size: 1  # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

packed_sequence:
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: True

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: validation
  limit_dataset_samples: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

# wandb:
#   project: <your_wandb_project>
#   entity: <your_wandb_entity>
#   name: <your_wandb_exp_name>
#   save_dir: <your_wandb_save_dir>
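
The peft block in this second config enables LoRA with rank (dim) 8 and alpha 32, i.e. a scaling factor of alpha / dim = 4 on the low-rank update applied to every matched linear layer. A generic sketch of what those two numbers control (the standard LoRA formulation, not the nemo_automodel PeftConfig implementation):

    import torch.nn as nn

    class LoRALinear(nn.Module):
        """Standard LoRA adapter around a frozen linear layer (illustrative only)."""

        def __init__(self, base: nn.Linear, dim: int = 8, alpha: int = 32):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad_(False)           # base weights stay frozen
            self.lora_a = nn.Linear(base.in_features, dim, bias=False)
            self.lora_b = nn.Linear(dim, base.out_features, bias=False)
            nn.init.zeros_(self.lora_b.weight)    # adapter starts as a no-op
            self.scaling = alpha / dim            # 32 / 8 = 4.0 with the config above

        def forward(self, x):
            return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))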

nemo_automodel/components/distributed/fsdp2.py

Lines changed: 5 additions & 0 deletions
@@ -128,6 +128,11 @@ class FSDP2Manager:
         metadata={"help": "Enable activation checkpointing if True. Applies to linear layers."},
     )

+    defer_fsdp_grad_sync: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Defer FSDP gradient sync to only the final micro-batch before the optimizer step if True."},
+    )
+
     def __post_init__(self):
         """
         Post-initialization hook that sets up the distributed environment.

nemo_automodel/components/distributed/utils.py

Lines changed: 8 additions & 2 deletions
@@ -213,13 +213,16 @@ def reduce_loss(
     return loss, denominator


-def get_sync_ctx(model, is_optim_step):
+def get_sync_ctx(model, is_optim_step, defer_fsdp_grad_sync: bool):
     """
     Get the synchronization context for the model.

     Args:
         model: The model to synchronize.
         is_optim_step: Whether the current step is an optimizer step.
+        defer_fsdp_grad_sync: Controls FSDP2 gradient synchronization during gradient accumulation.
+            - True: disable gradient sync on non-final micro-batches (saves comm, can increase peak memory).
+            - False: always sync gradients on every micro-batch (more comm, lower peak memory).

     Returns:
         A context manager that synchronizes the model.
@@ -229,7 +232,10 @@ def get_sync_ctx(model, is_optim_step):
     # all-reduce for every micro-batch and greatly improves throughput.
     sync_ctx = nullcontext()
     if isinstance(model, dist.fsdp._fully_shard._fully_shard.FSDPModule):
-        model.set_requires_gradient_sync(is_optim_step)
+        if defer_fsdp_grad_sync:
+            model.set_requires_gradient_sync(is_optim_step)
+        else:
+            model.set_requires_gradient_sync(True)
     elif isinstance(model, torch.nn.parallel.DistributedDataParallel) and not is_optim_step:
         sync_ctx = model.no_sync()
     return sync_ctx
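
Putting the new argument in context: the recipe diffs below wrap each micro-batch's forward/backward in the context returned by get_sync_ctx, passing True for is_optim_step only on the final micro-batch. A simplified sketch of that pattern (the import path matches the file above; the loop, the model_wrapper name, and the (logits, labels) loss signature are illustrative assumptions, not the recipe code):

    from contextlib import nullcontext

    from nemo_automodel.components.distributed.utils import get_sync_ctx

    def run_micro_batches(model, batches, loss_fn, model_wrapper=None, is_train=True):
        num_batches = len(batches)
        for idx, batch in enumerate(batches):
            sync_ctx = (
                get_sync_ctx(
                    model,
                    idx == num_batches - 1,  # True only on the micro-batch that precedes the optimizer step
                    defer_fsdp_grad_sync=getattr(model_wrapper, "defer_fsdp_grad_sync", True),
                )
                if is_train
                else nullcontext()
            )
            labels = batch.pop("labels")
            with sync_ctx:
                loss = loss_fn(model(**batch).logits, labels)  # assumes a (logits, labels) loss
                if is_train:
                    loss.backward()

With defer_fsdp_grad_sync=True (the default), FSDP2 skips gradient synchronization on every micro-batch except the last; with False, it synchronizes on every micro-batch, trading extra communication for lower peak memory, as the docstring above describes.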

nemo_automodel/recipes/llm/kd.py

Lines changed: 9 additions & 1 deletion
@@ -173,7 +173,15 @@ def _forward_backward_step(
         train_ctx, batch = make_cp_batch_and_ctx(self.device_mesh, batch, labels)

         model = self.model_parts[0]
-        sync_ctx = get_sync_ctx(model, idx == num_batches - 1) if is_train else nullcontext()
+        sync_ctx = (
+            get_sync_ctx(
+                model,
+                idx == num_batches - 1,
+                defer_fsdp_grad_sync=getattr(self.model_wrapper, "defer_fsdp_grad_sync", True),
+            )
+            if is_train
+            else nullcontext()
+        )
         with train_ctx(), sync_ctx:
             # No grad for teacher forward
             with (

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 9 additions & 1 deletion
@@ -1221,7 +1221,15 @@ def _forward_backward_step(
             loss_buffer.append(local_loss.clone().detach())
         else:
             model = self.model_parts[0]
-            sync_ctx = get_sync_ctx(model, idx == num_batches - 1) if is_train else nullcontext()
+            sync_ctx = (
+                get_sync_ctx(
+                    model,
+                    idx == num_batches - 1,
+                    defer_fsdp_grad_sync=getattr(self.model_wrapper, "defer_fsdp_grad_sync", True),
+                )
+                if is_train
+                else nullcontext()
+            )
             with train_ctx(), sync_ctx:
                 if isinstance(self.loss_fn, FusedLinearCrossEntropy):
                     # use num_logits_to_keep to avoid full logits matrix in memory

nemo_automodel/recipes/vlm/finetune.py

Lines changed: 8 additions & 1 deletion
@@ -755,7 +755,14 @@ def _run_train_optim_step(self, batches, max_grad_norm: Optional[float] = None):
             labels = batch.pop("labels")

             train_ctx, batch = make_cp_batch_and_ctx(self.device_mesh, batch, labels)
-            with train_ctx(), get_sync_ctx(self.model, i == num_batches - 1):
+            with (
+                train_ctx(),
+                get_sync_ctx(
+                    self.model,
+                    i == num_batches - 1,
+                    defer_fsdp_grad_sync=getattr(self.model_wrapper, "defer_fsdp_grad_sync", True),
+                ),
+            ):
                 if isinstance(self.loss_fn, FusedLinearCrossEntropy):
                     # use num_logits_to_keep to avoid full logits matrix in memory
                     out = self.model(logits_to_keep=1, **batch)

tests/unit_tests/distributed/test_utils.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def test_get_sync_ctx(monkeypatch, patch_dist):
     class Plain(torch.nn.Linear):
         pass

-    ctx = du.get_sync_ctx(Plain(2, 2), is_optim_step=False)
+    ctx = du.get_sync_ctx(Plain(2, 2), is_optim_step=False, defer_fsdp_grad_sync=False)
     # entering/exiting the context must be a no-op
     with ctx:
         pass
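
A possible companion check for the new parameter (a sketch, not part of this commit, and ignoring the monkeypatch/patch_dist fixtures the existing test relies on): for a module that is neither FSDP2-sharded nor DDP-wrapped, the returned context should remain a no-op for either value of the flag.

    import pytest
    import torch

    from nemo_automodel.components.distributed import utils as du

    @pytest.mark.parametrize("defer", [True, False])
    def test_get_sync_ctx_plain_module_ignores_flag(defer):
        # A plain nn.Linear takes neither the FSDP2 nor the DDP branch in get_sync_ctx,
        # so entering/exiting the returned context must be a no-op regardless of the flag.
        ctx = du.get_sync_ctx(torch.nn.Linear(2, 2), is_optim_step=False, defer_fsdp_grad_sync=defer)
        with ctx:
            pass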

tests/unit_tests/recipes/test_finetune_vlm_helpers.py

Lines changed: 2 additions & 1 deletion
@@ -228,6 +228,7 @@ def test_run_train_step_supports_tensor_outputs(monkeypatch):
     recipe.cfg = _Cfg(fp8=None)
     recipe.lr_scheduler = None
     recipe.timestamp = 0.0
+    recipe.model_wrapper = None

     recipe._dp_allreduce = lambda tensor, include_cp=False: tensor
     recipe._get_dp_group_size = lambda include_cp=True: 1
@@ -251,7 +252,7 @@ def fake_calculate_loss(*args, **kwargs):
     )
     monkeypatch.setattr(
         "nemo_automodel.recipes.vlm.finetune.get_sync_ctx",
-        lambda model, is_last: nullcontext(),
+        lambda model, is_last, defer_fsdp_grad_sync=True: nullcontext(),
     )

     calculate_mock = MagicMock(side_effect=fake_calculate_loss)
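
Because the recipes read the flag via getattr(self.model_wrapper, "defer_fsdp_grad_sync", True), setting recipe.model_wrapper = None in this test (or running with a wrapper that predates the new field) simply falls back to the old behaviour of deferring the sync to the final micro-batch. A tiny illustration of that fallback (illustrative snippet, not repository code):

    class _Wrapper:
        defer_fsdp_grad_sync = False

    for wrapper in (None, object(), _Wrapper()):
        print(getattr(wrapper, "defer_fsdp_grad_sync", True))
    # prints: True, True, False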
