
Commit f942ec3

HuiyingLi and akoumpa authored
feat: support kimi-vl model (#1103)
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 6fa2b47 commit f942ec3

File tree

10 files changed: +2696 -5 lines changed


docs/model-coverage/vlm.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ NeMo Automodel supports [AutoModelForImageTextToText](https://huggingface.co/doc
 
 | Model | Dataset | FSDP2 | PEFT | Example YAML |
 |------------------------------------|-----------------------------|------------|------------|--------------|
+| Kimi-VL-A3B-Instruct | cord-v2 | Supported | Supported | [kimi2vl_cordv2.yaml](../../examples/vlm_finetune/kimi/kimi2vl_cordv2.yaml) |
 | Gemma 3-4B & 27B | naver-clova-ix & rdr-items | Supported | Supported | [gemma3_vl_4b_cord_v2.yaml](../../examples/vlm_finetune/gemma3/gemma3_vl_4b_cord_v2.yaml) |
 | Gemma 3n | naver-clova-ix & rdr-items | Supported | Supported | [gemma3n_vl_4b_medpix.yaml](../../examples/vlm_finetune/gemma3n/gemma3n_vl_4b_medpix.yaml) |
 | Qwen2-VL-2B-Instruct & Qwen2.5-VL-3B-Instruct | cord-v2 | Supported | Supported | [qwen2_5_vl_3b_rdr.yaml](../../examples/vlm_finetune/qwen2_5/qwen2_5_vl_3b_rdr.yaml) |
examples/vlm_finetune/kimi/kimi2vl_cordv2.yaml

Lines changed: 120 additions & 0 deletions

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

step_scheduler:
  global_batch_size: 16
  local_batch_size: 2
  ckpt_every_steps: 100
  val_every_steps: 100
  max_steps: 50

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  tp_size: 1
  cp_size: 1
  pp_size: 1
  dp_replicate_size: 1
  ep_size: 8
  sequence_parallel: false

autopipeline:
  _target_: nemo_automodel.components.distributed.pipelining.AutoPipeline
  pp_schedule: interleaved1f1b
  pp_microbatch_size: 2
  round_virtual_stages_to_pp_multiple: down
  scale_grads_in_schedule: false
  layers_per_stage: 7
  patch_inner_model: false
  patch_causal_lm_model: false

parallelizer:
  _target_: nemo_automodel.components.moe.parallelizer.parallelize_model
  activation_checkpointing: false

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: moonshotai/Kimi-VL-A3B-Instruct
  backend:
    _target_: nemo_automodel.components.models.common.BackendConfig
    attn: te
    linear: te
    rms_norm: te
    rope_fusion: true
    enable_deepep: true
    fake_balanced_gate: false
    enable_hf_state_dict_adapter: true
    enable_fsdp_optimizations: true

processor:
  _target_: transformers.AutoProcessor.from_pretrained
  pretrained_model_name_or_path: moonshotai/Kimi-VL-A3B-Instruct
  trust_remote_code: true

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/kimi2vl/
  model_save_format: safetensors
  save_consolidated: true

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
  fp32_upcast: false

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_cord_v2_dataset
  path_or_dataset: naver-clova-ix/cord-v2
  split: train

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  pin_memory: true
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.kimi_vl_collate_fn
    max_length: 2048

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_cord_v2_dataset
  path_or_dataset: naver-clova-ix/cord-v2
  split: validation

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader

optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-05
  weight_decay: 0.01
  betas:
    - 0.9
    - 0.95

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_language_model: false

# wandb:
#   project: <your_project_name>
#   entity: <your_entity_name>
#   name: kimi2vl_finetune
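Every section above wires a component through a dotted `_target_` path that is resolved and called with the sibling keys as keyword arguments, with nested `_target_` dicts instantiated first (as with `model.backend` and `dataloader.collate_fn`). A rough sketch of that pattern, assuming a hydra-style loader; `locate` and `instantiate` below are illustrative helpers, not NeMo Automodel's actual config machinery:

import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path like 'torch.optim.AdamW' to a Python object."""
    parts = path.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ModuleNotFoundError:
            continue
        # Walk the remaining attributes, e.g. AutoProcessor.from_pretrained.
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot locate {path!r}")


def instantiate(cfg: dict) -> Any:
    """Call the _target_ with sibling keys, instantiating nested configs first."""
    cfg = dict(cfg)  # avoid mutating the caller's config
    target = locate(cfg.pop("_target_"))
    kwargs = {
        k: instantiate(v) if isinstance(v, dict) and "_target_" in v else v
        for k, v in cfg.items()
    }
    return target(**kwargs)


# e.g. the `processor` section above would materialize as:
processor = instantiate({
    "_target_": "transformers.AutoProcessor.from_pretrained",
    "pretrained_model_name_or_path": "moonshotai/Kimi-VL-A3B-Instruct",
    "trust_remote_code": True,
})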

nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 50 additions & 0 deletions
@@ -253,6 +253,55 @@ def has_data(modality_list):
     return batch
 
 
+def kimi_vl_collate_fn(
+    examples: Sequence[Dict[str, Any]],
+    processor,
+    max_length: Optional[int] = None,
+) -> Dict[str, torch.Tensor]:
+    """Collate function for KimiVL processors."""
+    conversations = [example["conversation"] for example in examples]
+    texts = [
+        processor.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
+        for conversation in conversations
+    ]
+
+    images: List[Any] = []
+    for conversation in conversations:
+        for message in conversation:
+            content = message.get("content")
+            if isinstance(content, list):
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "image":
+                        images.append(item.get("image"))
+
+    processor_kwargs = {
+        "text": texts,
+        "return_tensors": "pt",
+        "padding": True,
+        "truncation": True,
+    }
+    if max_length is not None:
+        processor_kwargs["max_length"] = max_length
+        processor_kwargs["padding"] = "max_length"
+    if images:
+        processor_kwargs["images"] = images
+
+    batch = processor(**processor_kwargs)
+
+    labels = build_labels(
+        batch["input_ids"],
+        conversations,
+        processor,
+    )
+    batch["labels"] = labels[:, 1:]
+
+    input_shape = batch["input_ids"].shape
+    for key, value in list(batch.items()):
+        if isinstance(value, torch.Tensor) and value.shape == input_shape:
+            batch[key] = value[:, :-1]
+    return batch
+
+
 def nemotron_parse_collate_fn(
     examples: Sequence[Dict[str, Any]],
     processor,

@@ -383,6 +432,7 @@ def default_collate_fn(
 COLLATE_FNS = {
     "Qwen2_5_VLProcessor": qwen2_5_collate_fn,
     "Qwen3OmniMoeProcessor": qwen3_omni_collate_fn,
+    "KimiVLProcessor": kimi_vl_collate_fn,
     "NemotronParseProcessor": nemotron_parse_collate_fn,
     "default": default_collate_fn,
 }
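`COLLATE_FNS` dispatches on the processor class name, so batches built with a `KimiVLProcessor` now route to `kimi_vl_collate_fn`. The function expects each example to carry a `conversation` in chat-template form, with images embedded as `{"type": "image", "image": ...}` content items; labels are the input ids shifted by one, and every sequence-shaped tensor is trimmed by one position to match. A minimal sketch of that input schema and a direct call, assuming a placeholder PIL image and the checkpoint from the YAML above; the exact set of returned tensor keys depends on the processor:

from functools import partial

from PIL import Image
from transformers import AutoProcessor

from nemo_automodel.components.datasets.vlm.collate_fns import kimi_vl_collate_fn

processor = AutoProcessor.from_pretrained(
    "moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True
)

example = {
    "conversation": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": Image.new("RGB", (64, 64))},  # placeholder image
                {"type": "text", "text": "What does the receipt say?"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "A coffee order."}]},
    ]
}

# Roughly the partial that the YAML's dataloader.collate_fn section configures.
collate = partial(kimi_vl_collate_fn, processor=processor, max_length=2048)
batch = collate([example])
# batch holds input_ids, labels, and (per the processor) vision tensors,
# with inputs trimmed by one position to align with the shifted labels.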

nemo_automodel/components/models/deepseek_v3/model.py

Lines changed: 11 additions & 3 deletions
@@ -155,16 +155,24 @@ def __init__(
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: torch.Tensor | None = None,
         *,
+        inputs_embeds: torch.Tensor | None = None,
         position_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         padding_mask: torch.Tensor | None = None,
         **attn_kwargs: Any,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if (input_ids is None) == (inputs_embeds is None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids
+
         if position_ids is None:
+            seq_len = inputs_embeds.shape[1]
             position_ids = (
-                torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)
+                torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0).expand(inputs_embeds.shape[0], -1)
             )
 
         with torch.no_grad():
@@ -176,7 +184,7 @@ def forward(
                 cp_size=attn_kwargs.get("cp_size", 1),
             )
 
-        h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids
+        h = inputs_embeds
 
         # Apply the transformer layers.
         for layer in self.layers.values():
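The signature change makes `input_ids` and `inputs_embeds` mutually exclusive: callers either pass token ids, which are embedded internally as before, or pass pre-computed embeddings directly. The second path is what lets a vision-language wrapper splice image features into the token embeddings before the transformer layers run. A hedged sketch of such a caller; the wrapper function, tensor shapes, and `image_token_id` convention are illustrative assumptions, not the actual Kimi-VL wiring:

import torch


def forward_with_vision(language_model, vision_tower, input_ids, pixel_values, image_token_id):
    # Embed text tokens with the language model's own embedding table.
    inputs_embeds = language_model.embed_tokens(input_ids)   # (B, S, H)

    # Encode images; assume one feature vector per <image> placeholder token.
    image_embeds = vision_tower(pixel_values)                # (N, H)

    # Scatter image features over the placeholder positions. The number of
    # placeholder tokens across the batch must equal N for the copy to line up.
    mask = (input_ids == image_token_id).unsqueeze(-1)       # (B, S, 1)
    inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds.to(inputs_embeds.dtype))

    # After this commit, exactly one of input_ids / inputs_embeds may be passed.
    return language_model(inputs_embeds=inputs_embeds)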
Lines changed: 13 additions & 0 deletions
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
