from types import SimpleNamespace
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from lmms_engine.mapping_func import register_processor

from .config import ProcessorConfig


@register_processor("nanovlm")
class NanovlmDataProcessor:
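    """Data processor for NanoVLM-style training.

    Wraps a Hugging Face tokenizer and image processor: it tokenizes Qwen-style
    chat messages, expands the image/video placeholder tokens to the number of
    visual tokens each input occupies, builds loss-masked labels, and attaches
    the pixel tensors produced by the image processor.
    """
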
    def __init__(self, config: ProcessorConfig) -> None:
        self.config = config

    def build(self):
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.processor_name)

        # Load image processor from the same local/remote checkpoint as tokenizer.
        # `NanoVLM_init` now carries both tokenizer and preprocessor configs.
        loaded_processor = AutoProcessor.from_pretrained(self.config.processor_name)
        self.image_processor = getattr(loaded_processor, "image_processor", loaded_processor)

        self.image_token = self.config.extra_kwargs.get("image_token", "<|image_pad|>")
        self.video_token = self.config.extra_kwargs.get("video_token", "<|video_pad|>")

        for name, token in (("image_token", self.image_token), ("video_token", self.video_token)):
            if token not in self._tokenizer.get_vocab():
                raise ValueError(f"{name} {token} not found in the tokenizer vocab. Please use a Qwen3 tokenizer.")

        # Minimal processor-style facade bundling the tokenizer, placeholder tokens, and batch_decode.
        self.processor = SimpleNamespace(
            tokenizer=self._tokenizer,
            image_token=self.image_token,
            video_token=self.video_token,
            batch_decode=self._tokenizer.batch_decode,
        )

    def save_pretrained(self, output_dir: str):
        self._tokenizer.save_pretrained(output_dir)
        self.image_processor.save_pretrained(output_dir)

    def process(
        self,
        images: Optional[List[Image.Image]],
        hf_messages,
        videos=None,
        system_message: str = "You are a helpful assistant",
        add_system_prompt: bool = True,
        add_generation_prompt: bool = False,
        **kwargs,
    ):
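        """Tokenize one conversation and bundle it with its pixel inputs.

        Returns a dict with `input_ids` and `labels`; when images or video frames
        are present, the image processor's outputs (`pixel_values` and, if the
        processor provides them, `pixel_attention_mask` / `spatial_shapes`) are
        attached as well.
        """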
        flat_images, num_image_tokens, num_video_tokens = self._prepare_visual_inputs(
            images=images,
            videos=videos,
            hf_messages=hf_messages,
        )

        if flat_images is not None and len(flat_images) > 0:
            image_inputs = self.image_processor(images=flat_images, return_tensors="pt")
        else:
            # Text-only sample: no pixel inputs and no placeholder expansion.
            image_inputs = {}
            num_image_tokens = None
            num_video_tokens = None

        inputs = self.get_qwen_template_labels(
            hf_messages,
            num_image_tokens,
            num_video_tokens,
            system_message=system_message,
            add_system_prompt=add_system_prompt,
            add_generation_prompt=add_generation_prompt,
        )
        for key in ("pixel_values", "pixel_attention_mask", "spatial_shapes"):
            if key in image_inputs:
                inputs[key] = image_inputs[key]
        return inputs

    def get_qwen_template_labels(
        self,
        hf_messages,
        num_image_tokens: Optional[List[int]],
        num_video_tokens: Optional[List[int]],
        system_message: str = "You are a helpful assistant",
        add_system_prompt: bool = True,
        add_generation_prompt: bool = False,
    ):
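        """Build `input_ids` and loss-masked `labels` from Qwen chat-template messages."""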
        special_tokens = list(self._tokenizer.additional_special_tokens)
        special_tokens.extend(["<|im_start|>", "<|im_end|>"])
        unmask_tokens_idx = [self._tokenizer.convert_tokens_to_ids(t) for t in special_tokens]
        input_id, target = [], []
        image_start_from = 0
        video_start_from = 0
        if add_system_prompt and hf_messages[0]["role"] != "system":
            input_id += self._apply_chat_template([{"role": "system", "content": system_message}], tokenize=True)
            target += [-100] * len(input_id)

        for message in hf_messages:
            role = message["role"]
            encode_id = self._apply_chat_template([message], tokenize=True)
            if self.image_token_id in encode_id and num_image_tokens is not None:
                encode_id, used_images = self._expand_encode_id_image_tokens(
                    encode_id, num_image_tokens, image_start_from
                )
                image_start_from += used_images
            if self.video_token_id in encode_id and num_video_tokens is not None:
                encode_id, used_videos = self._expand_encode_id_video_tokens(
                    encode_id, num_video_tokens, video_start_from
                )
                video_start_from += used_videos
            # Nanovlm only understands image_token_id, so map video tokens to image tokens.
            encode_id = [self.image_token_id if t == self.video_token_id else t for t in encode_id]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [-100] * len(encode_id)
            else:
                # Assistant turn: mask the leading role-header tokens, supervise the rest.
                encode_id_copy = list(encode_id)
                encode_id_copy[:3] = [-100] * 3
                target += encode_id_copy

        if add_generation_prompt:
            generation_tokens = self._tokenizer.encode("<|im_start|>assistant\n")
            input_id += generation_tokens
            target += [-100] * len(generation_tokens)

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        # Keep special tokens supervised, but never train on image placeholder tokens.
        for idx, token_id in enumerate(input_id):
            if token_id in unmask_tokens_idx:
                target[idx] = token_id
            if token_id == self.image_token_id:
                target[idx] = -100

        input_id = torch.tensor(input_id, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)
        return dict(
            input_ids=input_id,
            labels=target,
        )

    def _expand_encode_id_image_tokens(
        self,
        encode_id: List[int],
        image_token_num: List[int],
        start_from: int = 0,
    ):
        """Expand each image placeholder into its run of image tokens, using counts from `image_token_num` starting at `start_from`."""
        image_pos = [i for i, x in enumerate(encode_id) if x == self.image_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(image_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.image_token_id] * image_token_num[idx + start_from])
            prev = pos + 1
        # Append the tokens remaining after the last placeholder.
        expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(image_pos)

    def _expand_encode_id_video_tokens(
        self,
        encode_id: List[int],
        video_token_num: List[int],
        start_from: int = 0,
    ):
        """Expand each video placeholder into its run of video tokens, using counts from `video_token_num` starting at `start_from`."""
        video_pos = [i for i, x in enumerate(encode_id) if x == self.video_token_id]
        expanded_encode_id = []
        prev = 0
        for idx, pos in enumerate(video_pos):
            expanded_encode_id.extend(encode_id[prev:pos])
            expanded_encode_id.extend([self.video_token_id] * video_token_num[idx + start_from])
            prev = pos + 1
        # Append the tokens remaining after the last placeholder.
        expanded_encode_id.extend(encode_id[prev:])

        return expanded_encode_id, len(video_pos)

    def _apply_chat_template(self, messages, tokenize: bool = False):
        result = self._tokenizer.apply_chat_template(messages, tokenize=tokenize)
        # Unwrap a nested (batched) result so a flat list of ids is always returned.
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    def _prepare_visual_inputs(
        self,
        images: Optional[List[Image.Image]],
        videos: Optional[Sequence],
        hf_messages,
    ) -> Tuple[Optional[List], Optional[List[int]], Optional[List[int]]]:
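        """Walk the messages, flatten images and video frames, and count placeholder tokens.

        Returns `(flat_images, num_image_tokens, num_video_tokens)`, or
        `(None, None, None)` when the sample carries no visual content.
        """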
        if images is None and videos is None:
            return None, None, None

        # Number of placeholder tokens each image (or video frame) occupies;
        # falls back to the image processor's max_num_patches.
        image_token_count = self.config.extra_kwargs.get(
            "image_token_count",
            getattr(self.image_processor, "max_num_patches", 256),
        )
        flat_images: List = []
        num_image_tokens: List[int] = []
        num_video_tokens: List[int] = []

        image_idx = 0
        video_idx = 0

        for message in hf_messages:
            content = message.get("content", [])
            if not isinstance(content, list):
                continue
            for item in content:
                item_type = item.get("type")
                if item_type == "image":
                    if images is None or image_idx >= len(images):
                        raise ValueError("Missing image input for <image> placeholder.")
                    flat_images.append(self._to_pil_image(images[image_idx]))
                    num_image_tokens.append(image_token_count)
                    image_idx += 1
                elif item_type == "video":
                    if videos is None or video_idx >= len(videos):
                        raise ValueError("Missing video input for <video> placeholder.")
                    frames = self._normalize_video_frames(videos[video_idx])
                    flat_images.extend([self._to_pil_image(frame) for frame in frames])
                    num_video_tokens.append(image_token_count * len(frames))
                    video_idx += 1

        if len(flat_images) == 0:
            return None, None, None

        return flat_images, (num_image_tokens or None), (num_video_tokens or None)

    def _normalize_video_frames(self, video) -> List:
        if isinstance(video, list):
            return video
        if isinstance(video, np.ndarray):
            if video.ndim == 3:
                return [video]
            if video.ndim == 4:
                return list(video)
        if torch.is_tensor(video):
            video_np = video.detach().cpu().numpy()
            if video_np.ndim == 3:
                return [video_np]
            if video_np.ndim == 4:
                return list(video_np)
        raise ValueError(f"Unsupported video format: {type(video)}")

    def _to_pil_image(self, image: object) -> Image.Image:
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 2:
                arr = arr[:, :, None]
            if arr.ndim != 3:
                raise ValueError(f"Unsupported image shape: {arr.shape}")
            # If channel-first, transpose to HWC.
            if arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.dtype != np.uint8:
                # Treat arrays with max <= 1 as [0, 1] floats; otherwise clip to [0, 255].
                max_val = float(arr.max()) if arr.size > 0 else 1.0
                if max_val <= 1.0:
                    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
                else:
                    arr = arr.clip(0, 255).astype(np.uint8)
            pil = Image.fromarray(arr)
            return pil.convert("RGB")
        raise ValueError(f"Unsupported image type: {type(image)}")

    @property
    def image_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.image_token)

    @property
    def video_token_id(self):
        return self._tokenizer.convert_tokens_to_ids(self.video_token)

    @property
    def tokenizer(self):
        return self._tokenizer
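
# Minimal usage sketch (illustrative only). `ProcessorConfig` is assumed to accept
# `processor_name` and `extra_kwargs`; the checkpoint name below is a hypothetical
# choice, and any tokenizer that defines <|image_pad|> / <|video_pad|> should work.
#
#     config = ProcessorConfig(
#         processor_name="Qwen/Qwen3-0.6B",  # assumed checkpoint
#         extra_kwargs={"image_token_count": 256},
#     )
#     processor = NanovlmDataProcessor(config)
#     processor.build()
#     sample = processor.process(
#         images=[Image.new("RGB", (224, 224))],
#         hf_messages=[
#             {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
#             {"role": "assistant", "content": [{"type": "text", "text": "A black square."}]},
#         ],
#     )
#     # sample holds input_ids, labels and, for visual samples, pixel_values.
#     # Whether placeholder expansion fires depends on the tokenizer's chat template
#     # emitting the image token for "image" content items.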