Skip to content

Commit fd706e7

Browse files
hemildesai and claude
authored and committed
fix: guard AutoConfig.from_pretrained in PP mask precomputation (#1378)
* fix: guard AutoConfig.from_pretrained in PP mask precomputation Wrap the AutoConfig.from_pretrained call in a try/except so that if it fails (e.g. network issues, invalid model name), the pipeline parallel mask precomputation is gracefully skipped with a warning instead of crashing the dataloader setup. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Hemil Desai <hemild@nvidia.com> * fix Signed-off-by: Hemil Desai <hemild@nvidia.com> --------- Signed-off-by: Hemil Desai <hemild@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
1 parent fa31028 commit fd706e7

File tree

3 files changed

+189
-23
lines changed

3 files changed

+189
-23
lines changed

examples/llm_finetune/deepseek_v32/deepseek_v32_hellaswag_pp.yaml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
# limitations under the License.
1414

1515
# Finetuning config for DeepSeek V3.2 on HellaSwag
16-
#
17-
# To run this recipe, please use the following command:
18-
# torchrun --nproc-per-node=8 recipes/llm/finetune.py --config examples/llm_finetune/deepseek_v32/deepseek_v32_hellaswag_pp.yaml
16+
# Replace /path/to/deepseek-v32 with the path to the DeepSeek V3.2 model
1917

2018
seed: 1234
2119

@@ -57,14 +55,16 @@ model:
5755
_target_: nemo_automodel.NeMoAutoModelForCausalLM.from_config
5856
config:
5957
_target_: nemo_automodel.components.models.deepseek_v32.config.DeepseekV32Config.from_pretrained
60-
pretrained_model_name_or_path: deepseek-ai/DeepSeek-V3.2
58+
pretrained_model_name_or_path: /path/to/deepseek-v32
59+
name_or_path: /path/to/deepseek-v32
6160
trust_remote_code: true
61+
load_base_model: true
6262
backend:
6363
_target_: nemo_automodel.components.models.common.BackendConfig
6464
attn: sdpa # TE requires the latest cudnn version so disabling by default
6565
linear: te
6666
rms_norm: te
67-
rope_fusion: true
67+
rope_fusion: false
6868
enable_deepep: true
6969
fake_balanced_gate: false
7070
enable_hf_state_dict_adapter: true
@@ -82,6 +82,9 @@ dataset:
8282
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
8383
path_or_dataset: rowan/hellaswag
8484
split: train
85+
tokenizer:
86+
_target_: transformers.AutoTokenizer.from_pretrained
87+
pretrained_model_name_or_path: /path/to/deepseek-v32
8588

8689
packed_sequence:
8790
packed_sequence_size: 0
@@ -97,6 +100,9 @@ validation_dataset:
97100
_target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
98101
path_or_dataset: rowan/hellaswag
99102
split: train
103+
tokenizer:
104+
_target_: transformers.AutoTokenizer.from_pretrained
105+
pretrained_model_name_or_path: /path/to/deepseek-v32
100106

101107
validation_dataloader:
102108
_target_: torchdata.stateful_dataloader.StatefulDataLoader

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -558,26 +558,32 @@ def build_dataloader(
558558
if pp_enabled:
559559
from nemo_automodel.components.datasets.utils import add_causal_masks_to_batch
560560

561-
hf_model_config = AutoConfig.from_pretrained(
562-
_get_model_name(cfg_model), trust_remote_code=compute_trust_remote_code_from_model(cfg_model)
563-
)
564-
565-
if "collate_fn" in dl_kwargs:
566-
# Case 1: PP enabled + collate_fn exists -> chain them
567-
# base_collate_fn -> add_causal_masks_to_batch
568-
base_collate_fn = dl_kwargs["collate_fn"]
561+
try:
562+
hf_model_config = AutoConfig.from_pretrained(
563+
_get_model_name(cfg_model), trust_remote_code=compute_trust_remote_code_from_model(cfg_model)
564+
)
565+
except Exception:
566+
logger.warning(
567+
"Failed to load model config for causal mask precomputation. "
568+
"Pipeline parallel mask precomputation will be skipped."
569+
)
570+
else:
571+
if "collate_fn" in dl_kwargs:
572+
# Case 1: PP enabled + collate_fn exists -> chain them
573+
# base_collate_fn -> add_causal_masks_to_batch
574+
base_collate_fn = dl_kwargs["collate_fn"]
569575

570-
def chained_collate_fn(batch, base_fn=base_collate_fn, config=hf_model_config):
571-
batch = base_fn(batch) # Apply base collate (padding, batching, etc.)
572-
batch = add_causal_masks_to_batch(batch, model_config=config) # Add masks
573-
return batch
576+
def chained_collate_fn(batch, base_fn=base_collate_fn, config=hf_model_config):
577+
batch = base_fn(batch) # Apply base collate (padding, batching, etc.)
578+
batch = add_causal_masks_to_batch(batch, model_config=config) # Add masks
579+
return batch
574580

575-
dl_kwargs["collate_fn"] = chained_collate_fn
576-
else:
577-
# Case 2: PP enabled + no collate_fn -> only add masks
578-
dl_kwargs["collate_fn"] = lambda batch, config=hf_model_config: add_causal_masks_to_batch(
579-
batch, model_config=config
580-
)
581+
dl_kwargs["collate_fn"] = chained_collate_fn
582+
else:
583+
# Case 2: PP enabled + no collate_fn -> only add masks
584+
dl_kwargs["collate_fn"] = lambda batch, config=hf_model_config: add_causal_masks_to_batch(
585+
batch, model_config=config
586+
)
581587

582588
try:
583589
import torch.multiprocessing as mp

tests/unit_tests/recipes/test_train_ft.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,160 @@ def test_build_model_and_optimizer_return_values():
11711171
# Tests for _get_model_name helper
11721172
# =============================================================================
11731173

1174+
# =============================================================================
1175+
# Tests for PP mask precomputation guard in build_dataloader
1176+
# =============================================================================
1177+
1178+
1179+
def test_build_dataloader_pp_autoconfig_failure_skips_mask_collate(caplog):
1180+
"""When AutoConfig.from_pretrained raises, mask precomputation is skipped and a warning is logged."""
1181+
cfg_ds = ConfigNode(
1182+
{
1183+
"_target_": "tests.unit_tests.recipes.test_train_ft.DummyIterableDataset",
1184+
"tokenizer": None,
1185+
"num_shards": 4,
1186+
}
1187+
)
1188+
cfg_dl = ConfigNode(
1189+
{
1190+
"_target_": "tests.unit_tests.recipes.test_train_ft.dl_factory_capture",
1191+
"num_workers": 0,
1192+
}
1193+
)
1194+
cfg_model = ConfigNode({"pretrained_model_name_or_path": "bad/model"})
1195+
cfg_ps = ConfigNode({})
1196+
1197+
with (
1198+
patch("nemo_automodel.recipes.llm.train_ft.AutoConfig.from_pretrained", side_effect=OSError("not found")),
1199+
caplog.at_level(logging.WARNING),
1200+
):
1201+
dl, tok = build_dataloader(
1202+
cfg_ds=cfg_ds,
1203+
cfg_dl=cfg_dl,
1204+
cfg_model=cfg_model,
1205+
cfg_ps=cfg_ps,
1206+
seed=123,
1207+
local_batch_size=2,
1208+
global_batch_size=4,
1209+
max_steps=None,
1210+
val_check_interval=None,
1211+
dp_rank=0,
1212+
dp_world_size=1,
1213+
pp_enabled=True,
1214+
)
1215+
1216+
assert "Failed to load model config for causal mask precomputation" in caplog.text
1217+
# collate_fn should NOT have been set since AutoConfig failed
1218+
mod = importlib.import_module("tests.unit_tests.recipes.test_train_ft")
1219+
captured = getattr(mod.dl_factory_capture, "captured")
1220+
assert "collate_fn" not in captured
1221+
1222+
1223+
def test_build_dataloader_pp_autoconfig_success_sets_mask_collate():
1224+
"""When AutoConfig.from_pretrained succeeds and no collate_fn exists, a mask-only collate is set."""
1225+
cfg_ds = ConfigNode(
1226+
{
1227+
"_target_": "tests.unit_tests.recipes.test_train_ft.DummyIterableDataset",
1228+
"tokenizer": None,
1229+
"num_shards": 4,
1230+
}
1231+
)
1232+
cfg_dl = ConfigNode(
1233+
{
1234+
"_target_": "tests.unit_tests.recipes.test_train_ft.dl_factory_capture",
1235+
"num_workers": 0,
1236+
}
1237+
)
1238+
cfg_model = ConfigNode({"pretrained_model_name_or_path": "good/model"})
1239+
cfg_ps = ConfigNode({})
1240+
1241+
mock_config = MagicMock()
1242+
with (
1243+
patch("nemo_automodel.recipes.llm.train_ft.AutoConfig.from_pretrained", return_value=mock_config),
1244+
patch("nemo_automodel.components.datasets.utils.add_causal_masks_to_batch", side_effect=lambda b, **kw: b),
1245+
):
1246+
dl, tok = build_dataloader(
1247+
cfg_ds=cfg_ds,
1248+
cfg_dl=cfg_dl,
1249+
cfg_model=cfg_model,
1250+
cfg_ps=cfg_ps,
1251+
seed=123,
1252+
local_batch_size=2,
1253+
global_batch_size=4,
1254+
max_steps=None,
1255+
val_check_interval=None,
1256+
dp_rank=0,
1257+
dp_world_size=1,
1258+
pp_enabled=True,
1259+
)
1260+
1261+
# collate_fn should have been set (mask-only path)
1262+
mod = importlib.import_module("tests.unit_tests.recipes.test_train_ft")
1263+
captured = getattr(mod.dl_factory_capture, "captured")
1264+
assert "collate_fn" in captured
1265+
assert callable(captured["collate_fn"])
1266+
1267+
1268+
def test_build_dataloader_pp_autoconfig_success_chains_existing_collate():
1269+
"""When AutoConfig.from_pretrained succeeds and collate_fn exists, they are chained."""
1270+
call_order = []
1271+
1272+
def my_collate(batch):
1273+
call_order.append("base")
1274+
return batch
1275+
1276+
cfg_ds = ConfigNode(
1277+
{
1278+
"_target_": "tests.unit_tests.recipes.test_train_ft.DummyIterableDataset",
1279+
"tokenizer": None,
1280+
"num_shards": 4,
1281+
}
1282+
)
1283+
cfg_dl = ConfigNode(
1284+
{
1285+
"_target_": "tests.unit_tests.recipes.test_train_ft.dl_factory_capture",
1286+
"num_workers": 0,
1287+
"collate_fn": my_collate,
1288+
}
1289+
)
1290+
cfg_model = ConfigNode({"pretrained_model_name_or_path": "good/model"})
1291+
cfg_ps = ConfigNode({})
1292+
1293+
mock_config = MagicMock()
1294+
1295+
def mock_add_masks(batch, model_config=None):
1296+
call_order.append("masks")
1297+
return batch
1298+
1299+
with (
1300+
patch("nemo_automodel.recipes.llm.train_ft.AutoConfig.from_pretrained", return_value=mock_config),
1301+
patch("nemo_automodel.components.datasets.utils.add_causal_masks_to_batch", side_effect=mock_add_masks),
1302+
):
1303+
dl, tok = build_dataloader(
1304+
cfg_ds=cfg_ds,
1305+
cfg_dl=cfg_dl,
1306+
cfg_model=cfg_model,
1307+
cfg_ps=cfg_ps,
1308+
seed=123,
1309+
local_batch_size=2,
1310+
global_batch_size=4,
1311+
max_steps=None,
1312+
val_check_interval=None,
1313+
dp_rank=0,
1314+
dp_world_size=1,
1315+
pp_enabled=True,
1316+
)
1317+
1318+
mod = importlib.import_module("tests.unit_tests.recipes.test_train_ft")
1319+
captured = getattr(mod.dl_factory_capture, "captured")
1320+
assert "collate_fn" in captured
1321+
chained_fn = captured["collate_fn"]
1322+
1323+
# Invoke the chained collate to verify ordering
1324+
chained_fn(["dummy_batch"])
1325+
assert call_order == ["base", "masks"]
1326+
1327+
11741328
@pytest.mark.parametrize("cfg_attrs,expected", [
11751329
# String config
11761330
({"config": "org/model-name"}, "org/model-name"),

0 commit comments

Comments (0)