
Commit b2db55e

HuiyingLi and adil-a authored
feat: add ministral3 configs and improve tie_emb detection (#915)
Signed-off-by: HuiyingLi <[email protected]>
Co-authored-by: Adil <[email protected]>
1 parent ca5651e commit b2db55e

8 files changed: 398 additions & 12 deletions

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration for fine-tuning ministral 3 14b

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 100
  val_every_steps: 100  # will run every x number of gradient steps
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: mistralai/Ministral-3-14B-Reasoning-2512
  torch_dtype: torch.bfloat16
  attn_implementation: eager

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/
  model_save_format: safetensors
  save_consolidated: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train[:1000]

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation[:500]

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_language_model: false

# Uncomment and configure for W&B logging
# wandb:
#   project:
#   entity:
#   name:
#   save_dir:
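Each top-level block in this config (and in the two sibling configs below) follows a `_target_` convention: the dotted path names a callable and the sibling keys become its keyword arguments when the recipe builds that component. As a rough, generic sketch of that pattern (this is not the actual NeMo Automodel config loader, and the file name below is hypothetical), the `dataset:` block above could be resolved along these lines:

import importlib

import yaml


def instantiate(block: dict):
    # Split the `_target_` dotted path into module and attribute, import it, and
    # call it with the remaining keys as keyword arguments.
    block = dict(block)
    module_path, _, attr = block.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    # Nested blocks that carry their own `_target_` are instantiated recursively.
    kwargs = {
        k: instantiate(v) if isinstance(v, dict) and "_target_" in v else v
        for k, v in block.items()
    }
    return target(**kwargs)


cfg = yaml.safe_load(open("ministral3_14b.yaml"))  # hypothetical file name for the config above
train_ds = instantiate(cfg["dataset"])

Note that values such as `torch_dtype: torch.bfloat16` and `dp_size: none` arrive from YAML as plain strings; the real loader presumably converts them before use, which this sketch does not attempt.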
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration for fine-tuning ministral 3 3b

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 100
  val_every_steps: 100  # will run every x number of gradient steps
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: mistralai/Ministral-3-3B-Reasoning-2512
  torch_dtype: torch.bfloat16
  attn_implementation: eager

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/
  model_save_format: safetensors
  save_consolidated: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train[:1000]

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation[:500]

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_language_model: false

# Uncomment and configure for W&B logging
# wandb:
#   project:
#   entity:
#   name:
#   save_dir:
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration for fine-tuning ministral 3 8b

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 100
  val_every_steps: 100  # will run every x number of gradient steps
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: mistralai/Ministral-3-8B-Reasoning-2512
  torch_dtype: torch.bfloat16
  attn_implementation: eager

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/
  model_save_format: safetensors
  save_consolidated: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train[:1000]

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation[:500]

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_language_model: false

# Uncomment and configure for W&B logging
# wandb:
#   project:
#   entity:
#   name:
#   save_dir:
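The `freeze_config:` block in all three configs trains only the language model while keeping the input embeddings and the vision tower frozen. As a minimal illustration of what such a setting implies (the sub-module names below follow common Hugging Face VLM conventions and are assumptions, not the actual NeMo Automodel freeze implementation):

import torch.nn as nn


def apply_freeze_config(
    model: nn.Module,
    freeze_embeddings: bool,
    freeze_vision_tower: bool,
    freeze_language_model: bool,
) -> None:
    # Freeze the input embeddings (standard HF accessor, assumed to exist on the model).
    if freeze_embeddings and hasattr(model, "get_input_embeddings"):
        for p in model.get_input_embeddings().parameters():
            p.requires_grad_(False)
    # Freeze the vision encoder; `vision_tower` is a common VLM attribute name, assumed here.
    if freeze_vision_tower and hasattr(model, "vision_tower"):
        for p in model.vision_tower.parameters():
            p.requires_grad_(False)
    # Train (or freeze) the language backbone; `language_model` is likewise an assumed name.
    if hasattr(model, "language_model"):
        for p in model.language_model.parameters():
            p.requires_grad_(not freeze_language_model)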

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 2 additions & 1 deletion
@@ -39,6 +39,7 @@
 )
 from nemo_automodel.components.checkpoint.addons import ConsolidatedHFAddon, PeftAddon
 from nemo_automodel.components.checkpoint.stateful_wrappers import ModelState, OptimizerState
+from nemo_automodel.components.checkpoint.utils import is_tied_word_embeddings

 if TYPE_CHECKING:
     from peft import PeftConfig

@@ -374,7 +375,7 @@ def load_base_model(
             key_mapping=getattr(model, "_checkpoint_conversion_mapping", None),
         )

-        is_tied_lm_head = getattr(getattr(model, "config", {}), "tie_word_embeddings", False)
+        is_tied_lm_head = is_tied_word_embeddings(model)
         self.config.original_model_root_dir = root_dir
         if hasattr(model, "tie_weights") and is_tied_lm_head:
             model.tie_weights()
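This one-line change matters for multimodal checkpoints whose `tie_word_embeddings` flag lives on a nested text config rather than on the top-level config. A small self-contained comparison of what the old and new lookups see, using SimpleNamespace stand-ins rather than real transformers configs:

from types import SimpleNamespace

# A composite-style config where the flag lives on the nested text config.
text_cfg = SimpleNamespace(tie_word_embeddings=True)
cfg = SimpleNamespace(text_config=text_cfg, get_text_config=lambda: text_cfg)
model = SimpleNamespace(config=cfg)

# Old detection: looks only at the top-level attribute and misses the nested flag.
old = getattr(getattr(model, "config", {}), "tie_word_embeddings", False)
print(old)  # False

# New detection defers to config.get_text_config(); this line mirrors the lookup inside
# is_tied_word_embeddings (the real helper takes an nn.Module, not a bare config).
new = bool(getattr(cfg.get_text_config(), "tie_word_embeddings", getattr(cfg, "tie_word_embeddings", False)))
print(new)  # True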

nemo_automodel/components/checkpoint/stateful_wrappers.py

Lines changed: 3 additions & 10 deletions
@@ -24,6 +24,8 @@
     set_optimizer_state_dict,
 )

+from nemo_automodel.components.checkpoint.utils import is_tied_word_embeddings
+
 _PREFIX = "model."


@@ -92,16 +94,7 @@ def __init__(
         - ["score."] for some classification heads
         """
         self.model = [model] if isinstance(model, torch.nn.Module) else model
-        self.is_tied_lm_head = getattr(getattr(self.model[0], "config", {}), "tie_word_embeddings", False)
-
-        non_tied_lm_head_models = {
-            "Qwen3OmniMoeThinkerForConditionalGeneration",  # complicated config structure
-            "InternVLForConditionalGeneration",  # even tho config says tie_word_embeddings=True, it's not
-        }
-        for m in non_tied_lm_head_models:
-            if m in type(self.model[0]).__name__:
-                self.is_tied_lm_head = False
-                break
+        self.is_tied_lm_head = is_tied_word_embeddings(self.model[0])

         if self.is_tied_lm_head:
             _, lm_head_param_name = _get_lm_head_weight_and_name(self.model[0])

The local exception list is removed from ModelState: the Qwen3 Omni special case moves into the shared is_tied_word_embeddings helper (so ModelState and load_base_model use the same detection logic), while the InternVL entry is dropped from the list.
nemo_automodel/components/checkpoint/utils.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn


def is_tied_word_embeddings(model: nn.Module) -> bool:
    """
    Check if the model's word embeddings are tied.

    Args:
        model (nn.Module): The model to check.

    Returns:
        bool: True if the model's word embeddings are tied, False otherwise.
    """
    non_tied_lm_head_models = {
        "Qwen3OmniMoeThinkerForConditionalGeneration",  # complicated config structure
    }
    model_class_name = type(model).__name__
    for m in non_tied_lm_head_models:
        if m in model_class_name:
            return False
    config = getattr(model, "config", None)
    text_config = getattr(config, "get_text_config", lambda: None)()
    return bool(getattr(text_config, "tie_word_embeddings", getattr(config, "tie_word_embeddings", False)))
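A minimal usage sketch of the new helper; the toy classes below are made up for illustration and only mimic the shape of a composite (multimodal) Hugging Face config:

import torch.nn as nn

from nemo_automodel.components.checkpoint.utils import is_tied_word_embeddings


class _TextCfg:
    tie_word_embeddings = True


class _Cfg:
    def get_text_config(self):
        return _TextCfg()


class _TiedToy(nn.Module):
    """Hypothetical model whose tie flag lives on a nested text config."""

    def __init__(self):
        super().__init__()
        self.config = _Cfg()


class Qwen3OmniMoeThinkerForConditionalGeneration(_TiedToy):
    """Same config, but the class name is on the helper's exclusion list."""


print(is_tied_word_embeddings(_TiedToy()))  # True: nested text config reports tied embeddings
print(is_tied_word_embeddings(Qwen3OmniMoeThinkerForConditionalGeneration()))  # False: excluded by class name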

tests/unit_tests/checkpoint/test_addons.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def __init__(self):
     _DummyModel.__name__ = "Qwen3OmniMoeThinkerForConditionalGeneration"

     model = _DummyModel()
-    state = ModelState(model)
+    state = ModelState([model])

     assert state.is_tied_lm_head is False
     assert not hasattr(state, "lm_head_param_name")

Wrapping the dummy in a list makes ModelState treat it as a pre-built model list, so `self.model[0]` resolves to the dummy regardless of whether `_DummyModel` subclasses `nn.Module`.
