
Commit 42dd631

[feat] NanoVLM Training support (#134)
* nanovlm
* tokenizer_init
* transformer-style
* style: fix linting issues with black and isort
1 parent 91558bb commit 42dd631

File tree

7 files changed: +608 -0 lines changed

examples/nanovlm/nanovlm_train.sh

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
#!/bin/bash
DATASET_PATH="/path/to/dataset.yaml"
PROCESSOR_NAME="LMMs-Lab-Speedrun/NanoVLM_Init"
MODEL_PATH="LMMs-Lab-Speedrun/NanoVLM_Init"
ATTN_IMPLEMENTATION="flash_attention_2"
PER_DEVICE_TRAIN_BATCH_SIZE=2
LEARNING_RATE=2.0e-04
WEIGHT_DECAY=0.0
GRADIENT_ACCUMULATION_STEPS=1
GRADIENT_CHECKPOINTING=true
NUM_TRAIN_EPOCHS=1
RUN_NAME="debug_nanovlm"
OUTPUT_DIR="./output/debug_nanovlm"
WARMUP_RATIO=0.1
MAX_STEPS=10000

torchrun --nproc_per_node="2" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="8000" \
    -m lmms_engine.launch.cli \
    trainer_type=fsdp2_trainer \
    dataset_config.dataset_path=${DATASET_PATH} \
    dataset_config.dataset_format=yaml \
    dataset_config.dataset_type=qwen3_vl_iterable \
    dataset_config.processor_config.processor_type=nanovlm \
    dataset_config.processor_config.processor_name=${PROCESSOR_NAME} \
    dataset_config.packing=false \
    dataset_config.packing_strategy=first_fit \
    dataset_config.packing_length=51200 \
    dataset_config.filter_overlong=true \
    dataset_config.video_backend=qwen_vl_utils \
    dataset_config.video_sampling_strategy=fps \
    dataset_config.video_max_pixels=50176 \
    dataset_config.video_max_frames=512 \
    model_config.load_from_pretrained_path=${MODEL_PATH} \
    model_config.attn_implementation=${ATTN_IMPLEMENTATION} \
    trainer_args.freeze_modules=["vision_model"] \
    trainer_args.per_device_train_batch_size=${PER_DEVICE_TRAIN_BATCH_SIZE} \
    trainer_args.learning_rate=${LEARNING_RATE} \
    trainer_args.weight_decay=${WEIGHT_DECAY} \
    trainer_args.gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
    trainer_args.gradient_checkpointing=${GRADIENT_CHECKPOINTING} \
    trainer_args.num_train_epochs=${NUM_TRAIN_EPOCHS} \
    trainer_args.warmup_ratio=${WARMUP_RATIO} \
    trainer_args.run_name=${RUN_NAME} \
    trainer_args.output_dir=${OUTPUT_DIR} \
    trainer_args.fsdp2=true \
    trainer_args.max_steps=${MAX_STEPS} \
    trainer_args.fsdp_config.transformer_layer_cls_to_wrap=["Qwen3DecoderLayer"] \
    trainer_args.fsdp_config.reshard_after_forward=false \
    trainer_args.sp_ulysses_degree=1 \
    trainer_args.use_liger_kernel=true \
    trainer_args.use_rmpad=true \
    trainer_args.dataloader_num_workers=0 \
    trainer_args.dataloader_prefetch_factor=null \
    trainer_args.print_batch_input_steps=5 \
    trainer_args.bf16=true \
    trainer_args.lr_scheduler_type=cosine \
    trainer_args.logging_steps=1 \
    trainer_args.report_to=[] \
    trainer_args.group_by_length=false \
    trainer_args.bf16=true
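
Note: the launch command freezes the vision tower with trainer_args.freeze_modules=["vision_model"], so only the language model and the projector receive gradient updates. As a rough illustration of what honoring such a flag usually means (a hedged sketch, not the actual lmms_engine trainer code; the helper name and module layout are assumptions):

# Hedged sketch: freezing named submodules before training starts.
import torch.nn as nn

def apply_freeze_modules(model: nn.Module, freeze_modules: list) -> None:
    for name, module in model.named_modules():
        # Freeze any submodule whose dotted name matches an entry,
        # e.g. "vision_model" also matches "model.vision_model".
        if any(name == m or name.endswith("." + m) for m in freeze_modules):
            for param in module.parameters():
                param.requires_grad = False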

src/lmms_engine/datasets/processor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .config import ProcessorConfig
 from .llava_processor import LLaVADataProcessor
 from .llava_video_processor import LLaVAVideoDataProcessor
+from .nanovlm_processor import NanovlmDataProcessor
 from .pure_text_processor import PureTextDataProcessor
 from .qwen2_5_omni_processor import Qwen2_5OmniDataProcessor
 from .qwen2_5_vl_processor import Qwen2_5_VLDataProcessor
@@ -21,6 +22,7 @@
     "BaseQwen2_5_DataProcessor",
     "LLaVADataProcessor",
     "LLaVAVideoDataProcessor",
+    "NanovlmDataProcessor",
     "Qwen2_5OmniDataProcessor",
     "Qwen3OmniMoeDataProcessor",
     "Qwen2_5_VLDataProcessor",

src/lmms_engine/datasets/processor/nanovlm_processor.py

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
from types import SimpleNamespace
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from lmms_engine.mapping_func import register_processor

from .config import ProcessorConfig


@register_processor("nanovlm")
class NanovlmDataProcessor:
    def __init__(self, config: ProcessorConfig) -> None:
        self.config = config

    def build(self):
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.processor_name)

        # Load image processor from the same local/remote checkpoint as tokenizer.
        # `NanoVLM_init` now carries both tokenizer and preprocessor configs.
        loaded_processor = AutoProcessor.from_pretrained(self.config.processor_name)
        self.image_processor = getattr(loaded_processor, "image_processor", loaded_processor)

        self.image_token = self.config.extra_kwargs.get("image_token", "<|image_pad|>")
        self.video_token = self.config.extra_kwargs.get("video_token", "<|video_pad|>")

        for name, token in (("image_token", self.image_token), ("video_token", self.video_token)):
            if token not in self._tokenizer.get_vocab():
                raise ValueError(f"{name} {token} not found in tokenizer vocab. Please use a Qwen3 token.")

        self.processor = SimpleNamespace(
            tokenizer=self._tokenizer,
            image_token=self.image_token,
            video_token=self.video_token,
            batch_decode=self._tokenizer.batch_decode,
        )

    def save_pretrained(self, output_dir: str):
        self._tokenizer.save_pretrained(output_dir)
        self.image_processor.save_pretrained(output_dir)

    def process(
        self,
        images: Optional[List[Image.Image]],
        hf_messages,
        videos=None,
        system_message: str = "You are a helpful assistant",
        add_system_prompt=True,
        add_generation_prompt=False,
        **kwargs,
    ):
        flat_images, num_image_tokens, num_video_tokens = self._prepare_visual_inputs(
            images=images,
            videos=videos,
            hf_messages=hf_messages,
        )

        if flat_images is not None and len(flat_images) > 0:
            image_inputs = self.image_processor(images=flat_images, return_tensors="pt")
        else:
            image_inputs = {}
            num_image_tokens = None
            num_video_tokens = None

        inputs = self.get_qwen_template_labels(
            hf_messages,
            num_image_tokens,
            num_video_tokens,
            system_message=system_message,
            add_system_prompt=add_system_prompt,
            add_generation_prompt=add_generation_prompt,
        )
        for key in ("pixel_values", "pixel_attention_mask", "spatial_shapes"):
            if key in image_inputs:
                inputs[key] = image_inputs[key]
        return inputs

    def get_qwen_template_labels(
        self,
        hf_messages,
        num_image_tokens: Optional[List[int]],
        num_video_tokens: Optional[List[int]],
        system_message: str = "You are a helpful assistant",
        add_system_prompt: bool = True,
        add_generation_prompt: bool = False,
    ):
        special_tokens = list(self._tokenizer.additional_special_tokens)
        special_tokens.extend(["<|im_start|>", "<|im_end|>"])
        unmask_tokens_idx = [self._tokenizer.convert_tokens_to_ids(t) for t in special_tokens]
        input_id, target = [], []
        image_start_from = 0
        video_start_from = 0
        if add_system_prompt and hf_messages[0]["role"] != "system":
            input_id += self._apply_chat_template([{"role": "system", "content": system_message}], tokenize=True)
            target += [-100] * len(input_id)

        for message in hf_messages:
            role = message["role"]
            encode_id = self._apply_chat_template([message], tokenize=True)
            if self.image_token_id in encode_id and num_image_tokens is not None:
                encode_id, used_images = self._expand_encode_id_image_tokens(
                    encode_id, num_image_tokens, image_start_from
                )
                image_start_from += used_images
            if self.video_token_id in encode_id and num_video_tokens is not None:
                encode_id, used_videos = self._expand_encode_id_video_tokens(
                    encode_id, num_video_tokens, video_start_from
                )
                video_start_from += used_videos
                # Nanovlm only understands image_token_id, map video tokens to image tokens
                encode_id = [self.image_token_id if t == self.video_token_id else t for t in encode_id]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [-100] * len(encode_id)
            else:
                encode_id_copy = list(encode_id)
                encode_id_copy[:3] = [-100] * 3
                target += encode_id_copy

        if add_generation_prompt:
            generation_tokens = self._tokenizer.encode("<|im_start|>assistant\n")
            input_id += generation_tokens
            target += [-100] * len(generation_tokens)

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == self.image_token_id:
                target[idx] = -100

        input_id = torch.tensor(input_id, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)
        return dict(
            input_ids=input_id,
            labels=target,
        )

    def _expand_encode_id_image_tokens(
        self,
        encode_id: List[int],
        image_token_num: List[int],
        start_from: int = 0,
    ):
        image_pos = [i for i, x in enumerate(encode_id) if x == self.image_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(image_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.image_token_id] * image_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(image_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(image_pos)

    def _expand_encode_id_video_tokens(
        self,
        encode_id: List[int],
        video_token_num: List[int],
        start_from: int = 0,
    ):
        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(video_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.video_token_id] * video_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(video_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(video_pos)

    def _apply_chat_template(self, messages, tokenize: bool = False):
        result = self._tokenizer.apply_chat_template(messages, tokenize=tokenize)
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    def _prepare_visual_inputs(
        self,
        images: Optional[List[Image.Image]],
        videos: Optional[Sequence],
        hf_messages,
    ) -> Tuple[Optional[List], Optional[List[int]], Optional[List[int]]]:
        if images is None and videos is None:
            return None, None, None

        image_token_count = self.config.extra_kwargs.get(
            "image_token_count",
            getattr(self.image_processor, "max_num_patches", 256),
        )
        flat_images: List = []
        num_image_tokens: List[int] = []
        num_video_tokens: List[int] = []

        image_idx = 0
        video_idx = 0

        for message in hf_messages:
            content = message.get("content", [])
            if not isinstance(content, list):
                continue
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    if images is None or image_idx >= len(images):
                        raise ValueError("Missing image input for <image> placeholder.")
                    flat_images.append(self._to_pil_image(images[image_idx]))
                    num_image_tokens.append(image_token_count)
                    image_idx += 1
                elif item_type == "video":
                    if videos is None or video_idx >= len(videos):
                        raise ValueError("Missing video input for <video> placeholder.")
                    frames = self._normalize_video_frames(videos[video_idx])
                    flat_images.extend([self._to_pil_image(frame) for frame in frames])
                    num_video_tokens.append(image_token_count * len(frames))
                    video_idx += 1

        if len(flat_images) == 0:
            return None, None, None

        return flat_images, (num_image_tokens or None), (num_video_tokens or None)

    def _normalize_video_frames(self, video) -> List:
        if isinstance(video, list):
            return video
        if isinstance(video, np.ndarray):
            if video.ndim == 3:
                return [video]
            if video.ndim == 4:
                return [frame for frame in video]
        if torch.is_tensor(video):
            video_np = video.detach().cpu().numpy()
            if video_np.ndim == 3:
                return [video_np]
            if video_np.ndim == 4:
                return [frame for frame in video_np]
        raise ValueError(f"Unsupported video format: {type(video)}")

    def _to_pil_image(self, image: object) -> Image.Image:
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 2:
                arr = arr[:, :, None]
            if arr.ndim != 3:
                raise ValueError(f"Unsupported image shape: {arr.shape}")
            # If channel-first, transpose to HWC
            if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.dtype != np.uint8:
                max_val = float(arr.max()) if arr.size > 0 else 1.0
                if max_val <= 1.0:
                    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
                else:
                    arr = arr.clip(0, 255).astype(np.uint8)
            pil = Image.fromarray(arr)
            return pil.convert("RGB")
        raise ValueError(f"Unsupported image type: {type(image)}")

    @property
    def image_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.image_token)

    @property
    def video_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.video_token)

    @property
    def tokenizer(self):
        return self._tokenizer
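
For reference, a minimal way to drive NanovlmDataProcessor on its own might look like the sketch below. The ProcessorConfig fields (processor_type, processor_name, extra_kwargs) are inferred from the attributes read in build() and process() and from the launch script, so treat the constructor call as an assumption rather than the exact signature.

# Hedged sketch: calling the processor directly; ProcessorConfig fields are
# assumed from the attributes accessed above, not from its actual definition.
from PIL import Image

from lmms_engine.datasets.processor import NanovlmDataProcessor, ProcessorConfig

config = ProcessorConfig(
    processor_type="nanovlm",
    processor_name="LMMs-Lab-Speedrun/NanoVLM_Init",
    extra_kwargs={"image_token_count": 256},
)
processor = NanovlmDataProcessor(config)
processor.build()

hf_messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
    {"role": "assistant", "content": [{"type": "text", "text": "A red square."}]},
]
inputs = processor.process(images=[Image.new("RGB", (224, 224), "red")], hf_messages=hf_messages)
# Expected keys: input_ids and labels, plus pixel_values (and related tensors) when images are present.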

src/lmms_engine/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     apply_liger_kernel_to_llava_onevision1_5,
 )
 from .monkey_patch import MONKEY_PATCHER
+from .nanovlm import NanovlmConfig, NanovlmForConditionalGeneration
 from .qwen2 import apply_liger_kernel_to_qwen2
 from .qwen2_5_omni import (
     Qwen2_5OmniThinkerConfig,
@@ -67,6 +68,8 @@
     "LLaDADLLMConfig",
     "LLaDADLLMForMaskedLM",
     "MONKEY_PATCHER",
+    "NanovlmConfig",
+    "NanovlmForConditionalGeneration",
     "RaeSiglipConfig",
     "RaeSiglipModel",
     "SiTModel",

src/lmms_engine/models/nanovlm/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
from lmms_engine.mapping_func import register_model

from .configuration_nanovlm import NanovlmConfig
from .modeling_nanovlm import NanovlmForConditionalGeneration

register_model(
    "nanovlm",
    NanovlmConfig,
    NanovlmForConditionalGeneration,
    model_general_type="image_text_to_text",
)

__all__ = [
    "NanovlmConfig",
    "NanovlmForConditionalGeneration",
]
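
With the registration above, the model classes are also importable from lmms_engine.models (see the __init__.py diff). Assuming NanovlmForConditionalGeneration follows the usual transformers PreTrainedModel conventions, loading the checkpoint referenced by the training script would presumably look like this (a hedged sketch, not verified against the modeling code):

# Hedged sketch: loading the registered NanoVLM model the transformers way.
# Assumes the class inherits the standard from_pretrained API.
import torch

from lmms_engine.models import NanovlmConfig, NanovlmForConditionalGeneration

model = NanovlmForConditionalGeneration.from_pretrained(
    "LMMs-Lab-Speedrun/NanoVLM_Init",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
assert isinstance(model.config, NanovlmConfig)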
