
Commit 4970858

nanovlm
1 parent 91558bb commit 4970858

7 files changed: +618 −0 lines changed

examples/nanovlm/nanovlm_train.sh

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
#!/bin/bash

# Local paths: dataset YAML, Qwen3-0.6B checkpoint (LLM + tokenizer), and the
# SigLIP2 naflex checkpoint used as both image processor and vision tower.
DATASET_PATH="/mnt/umm/users/pufanyi/workspace/Show/lmms-engine/data/llava_next.yaml"
PROCESSOR_NAME="/mnt/umm/users/pufanyi/workspace/Show/CKPT/Qwen/Qwen3-0.6B"
MODEL_PATH="/mnt/umm/users/pufanyi/workspace/Show/CKPT/Qwen/Qwen3-0.6B"
SIGLIP_PROCESSOR="/mnt/umm/users/pufanyi/workspace/Show/CKPT/google/siglip2-so400m-patch16-naflex"

ATTN_IMPLEMENTATION="flash_attention_2"
PER_DEVICE_TRAIN_BATCH_SIZE=16
LEARNING_RATE=2.0e-04
WEIGHT_DECAY=0.0
GRADIENT_ACCUMULATION_STEPS=1
GRADIENT_CHECKPOINTING=true
NUM_TRAIN_EPOCHS=1
RUN_NAME="debug_nanovlm"
OUTPUT_DIR="./output/debug_nanovlm"
WARMUP_RATIO=0.1
MAX_STEPS=10000

# Id of the <|image_pad|> placeholder token in the Qwen3 tokenizer; the nanovlm
# processor expands each image placeholder to image_token_count copies of it.
IMAGE_TOKEN_ID=151655

torchrun --nproc_per_node="8" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="8000" \
    -m lmms_engine.launch.cli \
    trainer_type=fsdp2_trainer \
    dataset_config.dataset_path=${DATASET_PATH} \
    dataset_config.dataset_format=yaml \
    dataset_config.dataset_type=qwen3_vl_iterable \
    dataset_config.processor_config.processor_type=nanovlm \
    dataset_config.processor_config.processor_name=${PROCESSOR_NAME} \
    +dataset_config.processor_config.extra_kwargs.image_processor_name=${SIGLIP_PROCESSOR} \
    +dataset_config.processor_config.extra_kwargs.image_token_count=256 \
    dataset_config.packing=false \
    dataset_config.packing_strategy=first_fit \
    dataset_config.packing_length=51200 \
    dataset_config.filter_overlong=true \
    dataset_config.video_backend=qwen_vl_utils \
    dataset_config.video_sampling_strategy=fps \
    dataset_config.video_max_pixels=50176 \
    dataset_config.video_max_frames=512 \
    +model_config.load_from_config.model_type=nanovlm \
    model_config.load_from_pretrained_path=null \
    +model_config.load_from_config.config.llm_model_name=${MODEL_PATH} \
    +model_config.load_from_config.config.vision_model_name=${SIGLIP_PROCESSOR} \
    +model_config.load_from_config.config.image_token_id=${IMAGE_TOKEN_ID} \
    +model_config.load_from_config.config.vision_feature_dim=1152 \
    +model_config.load_from_config.config.image_token_count=256 \
    model_config.attn_implementation=${ATTN_IMPLEMENTATION} \
    trainer_args.freeze_modules=["visual"] \
    trainer_args.per_device_train_batch_size=${PER_DEVICE_TRAIN_BATCH_SIZE} \
    trainer_args.learning_rate=${LEARNING_RATE} \
    trainer_args.weight_decay=${WEIGHT_DECAY} \
    trainer_args.gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
    trainer_args.gradient_checkpointing=${GRADIENT_CHECKPOINTING} \
    trainer_args.num_train_epochs=${NUM_TRAIN_EPOCHS} \
    trainer_args.warmup_ratio=${WARMUP_RATIO} \
    trainer_args.run_name=${RUN_NAME} \
    trainer_args.output_dir=${OUTPUT_DIR} \
    trainer_args.fsdp2=true \
    trainer_args.max_steps=${MAX_STEPS} \
    trainer_args.fsdp_config.transformer_layer_cls_to_wrap=["Qwen3DecoderLayer"] \
    trainer_args.fsdp_config.reshard_after_forward=false \
    trainer_args.sp_ulysses_degree=1 \
    trainer_args.use_liger_kernel=true \
    trainer_args.use_rmpad=true \
    trainer_args.dataloader_num_workers=0 \
    trainer_args.dataloader_prefetch_factor=null \
    trainer_args.print_batch_input_steps=5 \
    trainer_args.lr_scheduler_type=cosine \
    trainer_args.logging_steps=1 \
    trainer_args.group_by_length=false \
    trainer_args.bf16=true
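
For reference, a minimal sketch of how the dataset_config.processor_config.* overrides above are expected to surface in the ProcessorConfig consumed by the new NanovlmDataProcessor added later in this commit. The attribute names processor_name and extra_kwargs come from how the processor reads its config; the exact ProcessorConfig constructor signature (and the processor_type field) is an assumption, and the hub-style checkpoint names stand in for the local paths above.

# Hypothetical sketch only: ProcessorConfig's real constructor may differ; the
# keyword names below mirror the attributes NanovlmDataProcessor actually reads
# (config.processor_name, config.extra_kwargs).
from lmms_engine.datasets.processor import NanovlmDataProcessor, ProcessorConfig

config = ProcessorConfig(
    processor_type="nanovlm",  # assumed field, mirroring the CLI override
    processor_name="Qwen/Qwen3-0.6B",  # ${PROCESSOR_NAME}
    extra_kwargs={
        "image_processor_name": "google/siglip2-so400m-patch16-naflex",  # ${SIGLIP_PROCESSOR}
        "image_token_count": 256,
    },
)
processor = NanovlmDataProcessor(config)
processor.build()  # loads the Qwen3 tokenizer and the SigLIP2 image processor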

src/lmms_engine/datasets/processor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .config import ProcessorConfig
 from .llava_processor import LLaVADataProcessor
 from .llava_video_processor import LLaVAVideoDataProcessor
+from .nanovlm_processor import NanovlmDataProcessor
 from .pure_text_processor import PureTextDataProcessor
 from .qwen2_5_omni_processor import Qwen2_5OmniDataProcessor
 from .qwen2_5_vl_processor import Qwen2_5_VLDataProcessor
@@ -21,6 +22,7 @@
     "BaseQwen2_5_DataProcessor",
     "LLaVADataProcessor",
     "LLaVAVideoDataProcessor",
+    "NanovlmDataProcessor",
     "Qwen2_5OmniDataProcessor",
     "Qwen3OmniMoeDataProcessor",
     "Qwen2_5_VLDataProcessor",
src/lmms_engine/datasets/processor/nanovlm_processor.py

Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
from types import SimpleNamespace
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from lmms_engine.mapping_func import register_processor

from .config import ProcessorConfig


@register_processor("nanovlm")
class NanovlmDataProcessor:
    """Pairs a Qwen3 tokenizer with a SigLIP2 image processor to build nanovlm training inputs."""

    def __init__(self, config: ProcessorConfig) -> None:
        self.config = config

    def build(self):
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.processor_name)

        image_processor_name = self.config.extra_kwargs.get(
            "image_processor_name", "google/siglip2-base-patch16-naflex"
        )
        self.image_processor = AutoProcessor.from_pretrained(image_processor_name)

        self.image_token = self.config.extra_kwargs.get("image_token", "<|image_pad|>")
        self.video_token = self.config.extra_kwargs.get("video_token", "<|video_pad|>")

        for name, token in (("image_token", self.image_token), ("video_token", self.video_token)):
            if token not in self._tokenizer.get_vocab():
                raise ValueError(f"{name} {token} not found in tokenizer vocab. Please use a Qwen3 token.")

        # Minimal processor-like facade exposing what downstream code expects.
        self.processor = SimpleNamespace(
            tokenizer=self._tokenizer,
            image_token=self.image_token,
            video_token=self.video_token,
            batch_decode=self._tokenizer.batch_decode,
        )

    def save_pretrained(self, output_dir: str):
        self._tokenizer.save_pretrained(output_dir)
        self.image_processor.save_pretrained(output_dir)

    def process(
        self,
        images: Optional[List[Image.Image]],
        hf_messages,
        videos=None,
        system_message: str = "You are a helpful assistant",
        add_system_prompt=True,
        add_generation_prompt=False,
        **kwargs,
    ):
        flat_images, num_image_tokens, num_video_tokens = self._prepare_visual_inputs(
            images=images,
            videos=videos,
            hf_messages=hf_messages,
        )

        if flat_images is not None and len(flat_images) > 0:
            image_inputs = self.image_processor(images=flat_images, return_tensors="pt")
        else:
            image_inputs = {}
            num_image_tokens = None
            num_video_tokens = None

        inputs = self.get_qwen_template_labels(
            hf_messages,
            num_image_tokens,
            num_video_tokens,
            system_message=system_message,
            add_system_prompt=add_system_prompt,
            add_generation_prompt=add_generation_prompt,
        )
        # Attach whatever visual tensors the image processor produced.
        for key in ("pixel_values", "pixel_attention_mask", "spatial_shapes"):
            if key in image_inputs:
                inputs[key] = image_inputs[key]
        return inputs

    def get_qwen_template_labels(
        self,
        hf_messages,
        num_image_tokens: Optional[List[int]],
        num_video_tokens: Optional[List[int]],
        system_message: str = "You are a helpful assistant",
        add_system_prompt: bool = True,
        add_generation_prompt: bool = False,
    ):
        template_messages = self._normalize_messages_for_template(hf_messages)
        special_tokens = list(self._tokenizer.additional_special_tokens)
        special_tokens.extend(["<|im_start|>", "<|im_end|>"])
        unmask_tokens_idx = [self._tokenizer.convert_tokens_to_ids(t) for t in special_tokens]
        input_id, target = [], []
        image_start_from = 0
        video_start_from = 0
        if add_system_prompt and template_messages[0]["role"] != "system":
            input_id += self._apply_chat_template(
                [{"role": "system", "content": system_message}], tokenize=True
            )
            target += [-100] * len(input_id)

        for message in template_messages:
            role = message["role"]
            encode_id = self._apply_chat_template([message], tokenize=True)
            if self.image_token_id in encode_id and num_image_tokens is not None:
                encode_id, used_images = self._expand_encode_id_image_tokens(
                    encode_id, num_image_tokens, image_start_from
                )
                image_start_from += used_images
            if self.video_token_id in encode_id and num_video_tokens is not None:
                encode_id, used_videos = self._expand_encode_id_video_tokens(
                    encode_id, num_video_tokens, video_start_from
                )
                video_start_from += used_videos
            # Nanovlm only understands image_token_id; map video tokens to image tokens.
            encode_id = [self.image_token_id if t == self.video_token_id else t for t in encode_id]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [-100] * len(encode_id)
            else:
                # Mask the leading role-header tokens of the assistant turn.
                encode_id_copy = list(encode_id)
                encode_id_copy[:3] = [-100] * 3
                target += encode_id_copy

        if add_generation_prompt:
            generation_tokens = self._tokenizer.encode("<|im_start|>assistant\n")
            input_id += generation_tokens
            target += [-100] * len(generation_tokens)

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        # Keep special chat tokens supervised, but never supervise image token positions.
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == self.image_token_id:
                target[idx] = -100

        input_id = torch.tensor(input_id, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)
        return dict(
            input_ids=input_id,
            labels=target,
        )

    def _expand_encode_id_image_tokens(
        self,
        encode_id: List[int],
        image_token_num: List[int],
        start_from: int = 0,
    ):
        # Replace each single image placeholder with image_token_num[...] copies of it.
        image_pos = [i for i, x in enumerate(encode_id) if x == self.image_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(image_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.image_token_id] * image_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(image_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(image_pos)

    def _expand_encode_id_video_tokens(
        self,
        encode_id: List[int],
        video_token_num: List[int],
        start_from: int = 0,
    ):
        # Same expansion as above, but for video placeholders.
        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(video_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.video_token_id] * video_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(video_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(video_pos)

    def _apply_chat_template(self, messages, tokenize: bool = False):
        result = self._tokenizer.apply_chat_template(messages, tokenize=tokenize)
        # Some tokenizers return a batch (list of lists) even for a single conversation.
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    def _prepare_visual_inputs(
        self,
        images: Optional[List[Image.Image]],
        videos: Optional[Sequence],
        hf_messages,
    ) -> Tuple[Optional[List], Optional[List[int]], Optional[List[int]]]:
        if images is None and videos is None:
            return None, None, None

        # Every image costs image_token_count tokens; a video costs image_token_count per frame.
        image_token_count = self.config.extra_kwargs.get("image_token_count", 256)
        flat_images: List = []
        num_image_tokens: List[int] = []
        num_video_tokens: List[int] = []

        image_idx = 0
        video_idx = 0

        for message in hf_messages:
            content = message.get("content", [])
            if not isinstance(content, list):
                continue
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    if images is None or image_idx >= len(images):
                        raise ValueError("Missing image input for <image> placeholder.")
                    flat_images.append(self._to_pil_image(images[image_idx]))
                    num_image_tokens.append(image_token_count)
                    image_idx += 1
                elif item_type == "video":
                    if videos is None or video_idx >= len(videos):
                        raise ValueError("Missing video input for <video> placeholder.")
                    frames = self._normalize_video_frames(videos[video_idx])
                    flat_images.extend([self._to_pil_image(frame) for frame in frames])
                    num_video_tokens.append(image_token_count * len(frames))
                    video_idx += 1

        if len(flat_images) == 0:
            return None, None, None

        return flat_images, (num_image_tokens or None), (num_video_tokens or None)

    def _normalize_video_frames(self, video) -> List:
        if isinstance(video, list):
            return video
        if isinstance(video, np.ndarray):
            if video.ndim == 3:
                return [video]
            if video.ndim == 4:
                return [frame for frame in video]
        if torch.is_tensor(video):
            video_np = video.detach().cpu().numpy()
            if video_np.ndim == 3:
                return [video_np]
            if video_np.ndim == 4:
                return [frame for frame in video_np]
        raise ValueError(f"Unsupported video format: {type(video)}")

    def _to_pil_image(self, image: object) -> Image.Image:
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 2:
                arr = arr[:, :, None]
            if arr.ndim != 3:
                raise ValueError(f"Unsupported image shape: {arr.shape}")
            # If channel-first, transpose to HWC.
            if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.dtype != np.uint8:
                max_val = float(arr.max()) if arr.size > 0 else 1.0
                if max_val <= 1.0:
                    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
                else:
                    arr = arr.clip(0, 255).astype(np.uint8)
            pil = Image.fromarray(arr)
            return pil.convert("RGB")
        raise ValueError(f"Unsupported image type: {type(image)}")

    def _normalize_messages_for_template(self, hf_messages):
        # Flatten HF-style content lists into a single string with Qwen vision/audio placeholders.
        normalized = []
        for message in hf_messages:
            content = message.get("content")
            if isinstance(content, list):
                parts = []
                for item in content:
                    if not isinstance(item, dict):
                        parts.append(str(item))
                        continue
                    item_type = item.get("type")
                    if item_type in ["image", "image_url"] or "image" in item:
                        parts.append("<|vision_start|><|image_pad|><|vision_end|>\n")
                    elif item_type in ["video", "video_url"] or "video" in item:
                        parts.append("<|vision_start|><|video_pad|><|vision_end|>\n")
                    elif item_type in ["audio", "audio_url"] or "audio" in item:
                        parts.append("<|AUDIO|>\n")
                    elif "text" in item:
                        parts.append(item["text"])
                normalized.append({"role": message["role"], "content": "".join(parts)})
            else:
                normalized.append(message)
        return normalized

    @property
    def image_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.image_token)

    @property
    def video_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.video_token)

    @property
    def tokenizer(self):
        return self._tokenizer
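
Below is a minimal, hypothetical usage sketch of the processor on a single-image chat sample, using the same assumed ProcessorConfig fields as the sketch after the training script. Only the process() call and the shape of its outputs follow from the code above; the checkpoint name is a placeholder.

# Hypothetical usage sketch; the ProcessorConfig arguments are assumptions.
from PIL import Image

from lmms_engine.datasets.processor import NanovlmDataProcessor, ProcessorConfig

config = ProcessorConfig(
    processor_type="nanovlm",  # assumed field
    processor_name="Qwen/Qwen3-0.6B",  # any tokenizer defining <|image_pad|> / <|video_pad|>
    extra_kwargs={"image_token_count": 256},
)
processor = NanovlmDataProcessor(config)
processor.build()

hf_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
    {"role": "assistant", "content": [{"type": "text", "text": "A red square."}]},
]
image = Image.new("RGB", (224, 224), color="red")

out = processor.process(images=[image], hf_messages=hf_messages)
# out["input_ids"] holds the chat-templated ids with the image placeholder
# expanded to 256 copies of <|image_pad|>; out["labels"] masks those positions,
# plus all system/user tokens, to -100. Visual tensors returned by the SigLIP2
# processor (pixel_values and, if provided, pixel_attention_mask /
# spatial_shapes) are attached alongside.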
