[feat] NanoVLM Training support #134
@@ -0,0 +1,76 @@
#!/bin/bash

DATASET_PATH="/mnt/umm/users/pufanyi/workspace/Show/lmms-engine/data/llava_next.yaml"
PROCESSOR_NAME="/mnt/umm/users/pufanyi/workspace/Show/CKPT/Qwen/Qwen3-0.6B"
MODEL_PATH="/mnt/umm/users/pufanyi/workspace/Show/CKPT/Qwen/Qwen3-0.6B"
SIGLIP_PROCESSOR="/mnt/umm/users/pufanyi/workspace/Show/CKPT/google/siglip2-so400m-patch16-naflex"

ATTN_IMPLEMENTATION="flash_attention_2"
PER_DEVICE_TRAIN_BATCH_SIZE=16
LEARNING_RATE=2.0e-04
WEIGHT_DECAY=0.0
GRADIENT_ACCUMULATION_STEPS=1
GRADIENT_CHECKPOINTING=true
NUM_TRAIN_EPOCHS=1
RUN_NAME="debug_nanovlm"
OUTPUT_DIR="./output/debug_nanovlm"
WARMUP_RATIO=0.1
MAX_STEPS=10000

# Id of the <|image_pad|> placeholder token in the Qwen3 tokenizer
IMAGE_TOKEN_ID=151655

torchrun --nproc_per_node="8" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="8000" \
    -m lmms_engine.launch.cli \
    trainer_type=fsdp2_trainer \
    dataset_config.dataset_path=${DATASET_PATH} \
    dataset_config.dataset_format=yaml \
    dataset_config.dataset_type=qwen3_vl_iterable \
    dataset_config.processor_config.processor_type=nanovlm \
    dataset_config.processor_config.processor_name=${PROCESSOR_NAME} \
    +dataset_config.processor_config.extra_kwargs.image_processor_name=${SIGLIP_PROCESSOR} \
    +dataset_config.processor_config.extra_kwargs.image_token_count=256 \
    dataset_config.packing=false \
    dataset_config.packing_strategy=first_fit \
    dataset_config.packing_length=51200 \
    dataset_config.filter_overlong=true \
    dataset_config.video_backend=qwen_vl_utils \
    dataset_config.video_sampling_strategy=fps \
    dataset_config.video_max_pixels=50176 \
    dataset_config.video_max_frames=512 \
    +model_config.load_from_config.model_type=nanovlm \
    model_config.load_from_pretrained_path=null \
    +model_config.load_from_config.config.llm_model_name=${MODEL_PATH} \
    +model_config.load_from_config.config.vision_model_name=${SIGLIP_PROCESSOR} \
    +model_config.load_from_config.config.image_token_id=${IMAGE_TOKEN_ID} \
    +model_config.load_from_config.config.vision_feature_dim=1152 \
    +model_config.load_from_config.config.image_token_count=256 \
    model_config.attn_implementation=${ATTN_IMPLEMENTATION} \
    trainer_args.freeze_modules=["visual"] \
    trainer_args.per_device_train_batch_size=${PER_DEVICE_TRAIN_BATCH_SIZE} \
    trainer_args.learning_rate=${LEARNING_RATE} \
    trainer_args.weight_decay=${WEIGHT_DECAY} \
    trainer_args.gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
    trainer_args.gradient_checkpointing=${GRADIENT_CHECKPOINTING} \
    trainer_args.num_train_epochs=${NUM_TRAIN_EPOCHS} \
    trainer_args.warmup_ratio=${WARMUP_RATIO} \
    trainer_args.run_name=${RUN_NAME} \
    trainer_args.output_dir=${OUTPUT_DIR} \
    trainer_args.fsdp2=true \
    trainer_args.max_steps=${MAX_STEPS} \
    trainer_args.fsdp_config.transformer_layer_cls_to_wrap=["Qwen3DecoderLayer"] \
    trainer_args.fsdp_config.reshard_after_forward=false \
    trainer_args.sp_ulysses_degree=1 \
    trainer_args.use_liger_kernel=true \
    trainer_args.use_rmpad=true \
    trainer_args.dataloader_num_workers=0 \
    trainer_args.dataloader_prefetch_factor=null \
    trainer_args.print_batch_input_steps=5 \
    trainer_args.bf16=true \
    trainer_args.lr_scheduler_type=cosine \
    trainer_args.logging_steps=1 \
    trainer_args.group_by_length=false
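
The script hard-codes IMAGE_TOKEN_ID=151655, which is expected to be the id of <|image_pad|> in the Qwen3 tokenizer (the processor below validates at build time that the token exists). A quick sanity-check sketch, using the public Hub ID "Qwen/Qwen3-0.6B" as a stand-in for the local checkpoint path:

# Sanity-check sketch (not part of the PR): verify that the launch script's
# IMAGE_TOKEN_ID matches the tokenizer actually used for training.
# "Qwen/Qwen3-0.6B" is a stand-in for the local PROCESSOR_NAME checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
image_token_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
video_token_id = tokenizer.convert_tokens_to_ids("<|video_pad|>")
print(image_token_id, video_token_id)
assert image_token_id == 151655, "Update IMAGE_TOKEN_ID in the launch script to match"
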
@@ -0,0 +1,305 @@
from types import SimpleNamespace
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from lmms_engine.mapping_func import register_processor

from .config import ProcessorConfig


@register_processor("nanovlm")
class NanovlmDataProcessor:
    def __init__(self, config: ProcessorConfig) -> None:
        self.config = config

    def build(self):
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.processor_name)

        image_processor_name = self.config.extra_kwargs.get(
            "image_processor_name", "google/siglip2-base-patch16-naflex"
        )
        self.image_processor = AutoProcessor.from_pretrained(image_processor_name)

        self.image_token = self.config.extra_kwargs.get("image_token", "<|image_pad|>")
        self.video_token = self.config.extra_kwargs.get("video_token", "<|video_pad|>")

        for name, token in (("image_token", self.image_token), ("video_token", self.video_token)):
            if token not in self._tokenizer.get_vocab():
                raise ValueError(f"{name} {token} not found in tokenizer vocab. Please use a Qwen3 token.")

        self.processor = SimpleNamespace(
            tokenizer=self._tokenizer,
            image_token=self.image_token,
            video_token=self.video_token,
            batch_decode=self._tokenizer.batch_decode,
        )

    def save_pretrained(self, output_dir: str):
        self._tokenizer.save_pretrained(output_dir)
        self.image_processor.save_pretrained(output_dir)

    def process(
        self,
        images: Optional[List[Image.Image]],
        hf_messages,
        videos=None,
        system_message: str = "You are a helpful assistant",
        add_system_prompt=True,
        add_generation_prompt=False,
        **kwargs,
    ):
        flat_images, num_image_tokens, num_video_tokens = self._prepare_visual_inputs(
            images=images,
            videos=videos,
            hf_messages=hf_messages,
        )

        if flat_images is not None and len(flat_images) > 0:
            image_inputs = self.image_processor(images=flat_images, return_tensors="pt")
        else:
            image_inputs = {}
            num_image_tokens = None
            num_video_tokens = None

        inputs = self.get_qwen_template_labels(
            hf_messages,
            num_image_tokens,
            num_video_tokens,
            system_message=system_message,
            add_system_prompt=add_system_prompt,
            add_generation_prompt=add_generation_prompt,
        )
        for key in ("pixel_values", "pixel_attention_mask", "spatial_shapes"):
            if key in image_inputs:
                inputs[key] = image_inputs[key]
        return inputs

    def get_qwen_template_labels(
        self,
        hf_messages,
        num_image_tokens: Optional[List[int]],
        num_video_tokens: Optional[List[int]],
        system_message: str = "You are a helpful assistant",
        add_system_prompt: bool = True,
        add_generation_prompt: bool = False,
    ):
        template_messages = self._normalize_messages_for_template(hf_messages)
        special_tokens = list(self._tokenizer.additional_special_tokens)
        special_tokens.extend(["<|im_start|>", "<|im_end|>"])
        unmask_tokens_idx = [self._tokenizer.convert_tokens_to_ids(t) for t in special_tokens]
        input_id, target = [], []
        image_start_from = 0
        video_start_from = 0
        if add_system_prompt and template_messages[0]["role"] != "system":
            input_id += self._apply_chat_template(
                [{"role": "system", "content": system_message}], tokenize=True
            )
            target += [-100] * len(input_id)

        for message in template_messages:
            role = message["role"]
            encode_id = self._apply_chat_template([message], tokenize=True)
            if self.image_token_id in encode_id and num_image_tokens is not None:
                encode_id, used_images = self._expand_encode_id_image_tokens(
                    encode_id, num_image_tokens, image_start_from
                )
                image_start_from += used_images
            if self.video_token_id in encode_id and num_video_tokens is not None:
                encode_id, used_videos = self._expand_encode_id_video_tokens(
                    encode_id, num_video_tokens, video_start_from
                )
                video_start_from += used_videos
            # NanoVLM only understands image_token_id; map video tokens to image tokens
            encode_id = [self.image_token_id if t == self.video_token_id else t for t in encode_id]
            input_id += encode_id
            if role in ["user", "system"]:
                # User/system turns are fully masked out of the loss
                target += [-100] * len(encode_id)
            else:
                # Assistant turns are supervised, except the first three tokens (the role header)
                encode_id_copy = list(encode_id)
                encode_id_copy[:3] = [-100] * 3
                target += encode_id_copy

        if add_generation_prompt:
            generation_tokens = self._tokenizer.encode("<|im_start|>assistant\n")
            input_id += generation_tokens
            target += [-100] * len(generation_tokens)

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == self.image_token_id:
                target[idx] = -100

        input_id = torch.tensor(input_id, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)
        return dict(
            input_ids=input_id,
            labels=target,
        )

    def _expand_encode_id_image_tokens(
        self,
        encode_id: List[int],
        image_token_num: List[int],
        start_from: int = 0,
    ):
        # Replace each single image placeholder with image_token_num[i] copies of the image token
        image_pos = [i for i, x in enumerate(encode_id) if x == self.image_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(image_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.image_token_id] * image_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(image_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(image_pos)

    def _expand_encode_id_video_tokens(
        self,
        encode_id: List[int],
        video_token_num: List[int],
        start_from: int = 0,
    ):
        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(video_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.video_token_id] * video_token_num[idx + start_from])
            prev = pos + 1

            if idx == len(video_pos) - 1:
                expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(video_pos)

    def _apply_chat_template(self, messages, tokenize: bool = False):
        result = self._tokenizer.apply_chat_template(messages, tokenize=tokenize)
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    def _prepare_visual_inputs(
        self,
        images: Optional[List[Image.Image]],
        videos: Optional[Sequence],
        hf_messages,
    ) -> Tuple[Optional[List], Optional[List[int]], Optional[List[int]]]:
        if images is None and videos is None:
            return None, None, None

        image_token_count = self.config.extra_kwargs.get("image_token_count", 256)
        flat_images: List = []
        num_image_tokens: List[int] = []
        num_video_tokens: List[int] = []

        image_idx = 0
        video_idx = 0

        for message in hf_messages:
            content = message.get("content", [])
            if not isinstance(content, list):
                continue
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    if images is None or image_idx >= len(images):
                        raise ValueError("Missing image input for <image> placeholder.")
                    flat_images.append(self._to_pil_image(images[image_idx]))
                    num_image_tokens.append(image_token_count)
                    image_idx += 1
                elif item_type == "video":
                    if videos is None or video_idx >= len(videos):
                        raise ValueError("Missing video input for <video> placeholder.")
                    # A video is flattened into frames; each frame costs image_token_count tokens
                    frames = self._normalize_video_frames(videos[video_idx])
                    flat_images.extend([self._to_pil_image(frame) for frame in frames])
                    num_video_tokens.append(image_token_count * len(frames))
                    video_idx += 1

        if len(flat_images) == 0:
            return None, None, None

        return flat_images, (num_image_tokens or None), (num_video_tokens or None)

    def _normalize_video_frames(self, video) -> List:
        if isinstance(video, list):
            return video
        if isinstance(video, np.ndarray):
            if video.ndim == 3:
                return [video]
            if video.ndim == 4:
                return [frame for frame in video]
        if torch.is_tensor(video):
            video_np = video.detach().cpu().numpy()
            if video_np.ndim == 3:
                return [video_np]
            if video_np.ndim == 4:
                return [frame for frame in video_np]
        raise ValueError(f"Unsupported video format: {type(video)}")

    def _to_pil_image(self, image: object) -> Image.Image:
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 2:
                arr = arr[:, :, None]
            if arr.ndim != 3:
                raise ValueError(f"Unsupported image shape: {arr.shape}")
            # If channel-first, transpose to HWC
            if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.dtype != np.uint8:
                max_val = float(arr.max()) if arr.size > 0 else 1.0
                if max_val <= 1.0:
                    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
                else:
                    arr = arr.clip(0, 255).astype(np.uint8)
            pil = Image.fromarray(arr)
            return pil.convert("RGB")
        raise ValueError(f"Unsupported image type: {type(image)}")

    def _normalize_messages_for_template(self, hf_messages):
        # Collapse structured content lists into plain chat-template strings,
        # inserting one placeholder token per image/video/audio item
        normalized = []
        for message in hf_messages:
            content = message.get("content")
            if isinstance(content, list):
                parts = []
                for item in content:
                    if not isinstance(item, dict):
                        parts.append(str(item))
                        continue
                    item_type = item.get("type")
                    if item_type in ["image", "image_url"] or "image" in item:
                        parts.append("<|vision_start|><|image_pad|><|vision_end|>\n")
                    elif item_type in ["video", "video_url"] or "video" in item:
                        parts.append("<|vision_start|><|video_pad|><|vision_end|>\n")
                    elif item_type in ["audio", "audio_url"] or "audio" in item:
                        parts.append("<|AUDIO|>\n")
                    elif "text" in item:
                        parts.append(item["text"])
                normalized.append({"role": message["role"], "content": "".join(parts)})
            else:
                normalized.append(message)
        return normalized

    @property
    def image_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.image_token)

    @property
    def video_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.video_token)

    @property
    def tokenizer(self):
        return self._tokenizer
Later, changing these to a public path (in the repo) or an HF Hub path would be better.
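
In that spirit, here is a minimal usage sketch of the processor above with public Hub IDs instead of the local checkpoints. It is not part of the PR: the exact ProcessorConfig fields (processor_name, extra_kwargs) and the import paths are assumptions inferred from how they are used in build() and _prepare_visual_inputs().

# Usage sketch (not from the PR). Assumes ProcessorConfig exposes the
# processor_name and extra_kwargs fields referenced by the class above;
# the import paths below are hypothetical.
from PIL import Image

from lmms_engine.datasets.processor.config import ProcessorConfig  # hypothetical path
from lmms_engine.datasets.processor.nanovlm import NanovlmDataProcessor  # hypothetical path

config = ProcessorConfig(
    processor_type="nanovlm",
    processor_name="Qwen/Qwen3-0.6B",  # Hub ID instead of a local checkpoint
    extra_kwargs={
        "image_processor_name": "google/siglip2-so400m-patch16-naflex",
        "image_token_count": 256,
    },
)

processor = NanovlmDataProcessor(config)
processor.build()

# One user turn with a single image plus a question, in the structured
# "hf_messages" format the processor expects, followed by an assistant answer.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is in this picture?"},
        ],
    },
    {"role": "assistant", "content": [{"type": "text", "text": "A red square."}]},
]

image = Image.new("RGB", (224, 224), color="red")
inputs = processor.process(images=[image], hf_messages=messages)

# input_ids/labels come from get_qwen_template_labels; pixel_values etc. from the SigLIP2 processor
print(inputs["input_ids"].shape, inputs["labels"].shape, inputs["pixel_values"].shape)
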