|
15 | 15 | from PIL.Image import Image |
16 | 16 | from transformers import ( |
17 | 17 | AutoConfig, |
| 18 | + AutoImageProcessor, |
18 | 19 | GenerationConfig, |
19 | 20 | GenerationMixin, |
20 | 21 | PretrainedConfig, |
|
24 | 25 |
|
25 | 26 | from ...exporters.openvino import main_export |
26 | 27 | from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name |
| 28 | +from ...exporters.openvino.utils import save_config |
27 | 29 | from .. import OVQuantizer |
28 | 30 | from .configuration import OVConfig, OVWeightQuantizationConfig |
29 | 31 | from .modeling_base import OVBaseModel, OVModelPart |
@@ -319,6 +321,13 @@ def compile(self): |
319 | 321 | if part_model is not None: |
320 | 322 | part_model._compile() |
321 | 323 |
|
| 324 | + def _save_config(self, save_directory): |
| 325 | + """ |
| 326 | + Saves a model configuration into a directory, so that it can be re-loaded using the |
| 327 | + [`from_pretrained`] class method. |
| 328 | + """ |
| 329 | + save_config(self.config, save_directory) |
| 330 | + |
322 | 331 | def _save_pretrained(self, save_directory: Union[str, Path]): |
323 | 332 | """ |
324 | 333 | Saves the model to the OpenVINO IR format so that it can be re-loaded using the |
@@ -728,9 +737,9 @@ def can_generate(self): |
728 | 737 | @staticmethod |
729 | 738 | @abstractmethod |
730 | 739 | def preprocess_inputs( |
731 | | - processor, |
732 | 740 | text: str, |
733 | 741 | image: Optional[Image] = None, |
| 742 | + processor: Optional[AutoImageProcessor] = None, |
734 | 743 | tokenizer: Optional[PreTrainedTokenizer] = None, |
735 | 744 | ): |
736 | 745 | """ |
@@ -902,15 +911,23 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values): |
902 | 911 |
|
903 | 912 | @staticmethod |
904 | 913 | def preprocess_inputs( |
905 | | - processor, |
906 | 914 | text: str, |
907 | 915 | image: Optional[Image] = None, |
| 916 | + processor: Optional[AutoImageProcessor] = None, |
908 | 917 | tokenizer: Optional[PreTrainedTokenizer] = None, |
909 | 918 | ): |
910 | | - if image is None: |
911 | | - raise ValueError("Image is required.") |
912 | | - chat_template = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image"}]}] |
913 | | - prompt = processor.apply_chat_template(chat_template, add_generation_prompt=True) |
| 919 | + if processor is None: |
| 920 | + raise ValueError("Processor is required.") |
| 921 | + if getattr(processor, "chat_template", None) is not None: |
| 922 | + chat_prompt = [{"role": "user", "content": [{"type": "text", "text": text}]}] |
| 923 | + if image is not None: |
| 924 | + chat_prompt[0]["content"].append({"type": "image"}) |
| 925 | + prompt = processor.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False) |
| 926 | + else: |
| 927 | + if image is not None and "<image>" not in text: |
| 928 | + prompt = "<image>\n" + text |
| 929 | + else: |
| 930 | + prompt = text |
914 | 931 | inputs = processor(images=image, text=prompt, return_tensors="pt") |
915 | 932 | return inputs |
916 | 933 |
|
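With the reordered signature above, `image` and `processor` become optional keyword arguments. A minimal sketch of both call forms, assuming `model` is an already loaded `OVModelForVisualCausalLM` and `processor` the matching `AutoProcessor` (both are illustrative assumptions, not part of this change):

```python
from PIL import Image

image = Image.open("example.jpg")  # any RGB image

# image + text: when the processor carries a chat template, an image slot is added to the user turn
inputs = model.preprocess_inputs(text="What is on the picture?", image=image, processor=processor)

# text-only: no image entry is appended and the prompt stays plain text
inputs = model.preprocess_inputs(text="Summarize the previous answer.", processor=processor)
```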
@@ -1209,6 +1226,159 @@ def merge_vision_text_embeddings( |
1209 | 1226 | input_embeds = input_embeds.reshape(B, N, C) |
1210 | 1227 | return input_embeds, attention_mask, position_ids |
1211 | 1228 |
|
| 1229 | + def preprocess_inputs( |
| 1230 | + self, |
| 1231 | + text: str, |
| 1232 | + image: Optional[Image] = None, |
| 1233 | + processor: Optional[AutoImageProcessor] = None, |
| 1234 | + tokenizer: Optional[PreTrainedTokenizer] = None, |
| 1235 | + ): |
| 1236 | + if tokenizer is None: |
| 1237 | + raise ValueError("Tokenizer is required.") |
| 1238 | + import torchvision.transforms as T |
| 1239 | + from torchvision.transforms.functional import InterpolationMode |
| 1240 | + |
| 1241 | + IMG_START_TOKEN = "<img>" |
| 1242 | + IMG_END_TOKEN = "</img>" |
| 1243 | + IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" |
| 1244 | + |
| 1245 | + IMAGENET_MEAN = (0.485, 0.456, 0.406) |
| 1246 | + IMAGENET_STD = (0.229, 0.224, 0.225) |
| 1247 | + |
| 1248 | + def build_transform(input_size): |
| 1249 | + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD |
| 1250 | + transform = T.Compose( |
| 1251 | + [ |
| 1252 | + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), |
| 1253 | + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), |
| 1254 | + T.ToTensor(), |
| 1255 | + T.Normalize(mean=MEAN, std=STD), |
| 1256 | + ] |
| 1257 | + ) |
| 1258 | + return transform |
| 1259 | + |
| 1260 | + def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
| 1261 | + best_ratio_diff = float("inf") |
| 1262 | + best_ratio = (1, 1) |
| 1263 | + area = width * height |
| 1264 | + for ratio in target_ratios: |
| 1265 | + target_aspect_ratio = ratio[0] / ratio[1] |
| 1266 | + ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
| 1267 | + if ratio_diff < best_ratio_diff: |
| 1268 | + best_ratio_diff = ratio_diff |
| 1269 | + best_ratio = ratio |
| 1270 | + elif ratio_diff == best_ratio_diff: |
| 1271 | + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
| 1272 | + best_ratio = ratio |
| 1273 | + return best_ratio |
| 1274 | + |
| 1275 | + def dynamic_preprocess(image, min_num=1, max_num=12, image_size=28, use_thumbnail=False): |
| 1276 | + orig_width, orig_height = image.size |
| 1277 | + aspect_ratio = orig_width / orig_height |
| 1278 | + |
| 1280 | + # enumerate the candidate tiling grids (columns x rows) within the min/max tile budget |
| 1280 | + target_ratios = { |
| 1281 | + (i, j) |
| 1282 | + for n in range(min_num, max_num + 1) |
| 1283 | + for i in range(1, n + 1) |
| 1284 | + for j in range(1, n + 1) |
| 1285 | + if i * j <= max_num and i * j >= min_num |
| 1286 | + } |
| 1287 | + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
| 1288 | + |
| 1289 | + # find the closest aspect ratio to the target |
| 1290 | + target_aspect_ratio = find_closest_aspect_ratio( |
| 1291 | + aspect_ratio, target_ratios, orig_width, orig_height, image_size |
| 1292 | + ) |
| 1293 | + |
| 1294 | + # calculate the target width and height |
| 1295 | + target_width = image_size * target_aspect_ratio[0] |
| 1296 | + target_height = image_size * target_aspect_ratio[1] |
| 1297 | + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] |
| 1298 | + |
| 1299 | + # resize the image |
| 1300 | + resized_img = image.resize((target_width, target_height)) |
| 1301 | + processed_images = [] |
| 1302 | + for i in range(blocks): |
| 1303 | + box = ( |
| 1304 | + (i % (target_width // image_size)) * image_size, |
| 1305 | + (i // (target_width // image_size)) * image_size, |
| 1306 | + ((i % (target_width // image_size)) + 1) * image_size, |
| 1307 | + ((i // (target_width // image_size)) + 1) * image_size, |
| 1308 | + ) |
| 1309 | + # split the image |
| 1310 | + split_img = resized_img.crop(box) |
| 1311 | + processed_images.append(split_img) |
| 1312 | + assert len(processed_images) == blocks |
| 1313 | + if use_thumbnail and len(processed_images) != 1: |
| 1314 | + thumbnail_img = image.resize((image_size, image_size)) |
| 1315 | + processed_images.append(thumbnail_img) |
| 1316 | + return processed_images |
| 1317 | + |
| 1318 | + def load_image(image, input_size=448, max_num=12): |
| 1319 | + transform = build_transform(input_size=input_size) |
| 1320 | + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) |
| 1321 | + pixel_values = [transform(image) for image in images] |
| 1322 | + pixel_values = torch.stack(pixel_values) |
| 1323 | + return pixel_values |
| 1324 | + |
| 1325 | + if image is not None: |
| 1326 | + if "<image>" not in text: |
| 1327 | + text = "<image>\n" + text |
| 1328 | + pixel_values = load_image(image, input_size=self.config.vision_config.image_size) |
| 1329 | + num_patches = pixel_values.shape[0] |
| 1330 | + num_image_token = int( |
| 1331 | + (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2 |
| 1332 | + * (self.config.downsample_ratio**2) |
| 1333 | + ) |
| 1334 | + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN |
| 1335 | + text = text.replace("<image>", image_tokens, 1) |
| 1336 | + text_inputs = tokenizer(text, return_tensors="pt") |
| 1337 | + inputs = dict(text_inputs) |
| 1338 | + inputs.update({"pixel_values": pixel_values}) |
| 1339 | + else: |
| 1340 | + inputs = tokenizer(text, return_tensors="pt") |
| 1341 | + return inputs |
| 1342 | + |
| 1343 | + # InternVL has an issue with the _get_non_default_parameters check; as a workaround, override _prepare_generation_config |
| 1344 | + def _prepare_generation_config( |
| 1345 | + self, generation_config: Optional[GenerationConfig], **kwargs: Dict |
| 1346 | + ) -> Tuple[GenerationConfig, Dict]: |
| 1347 | + using_model_generation_config = False |
| 1348 | + if generation_config is None: |
| 1349 | + if ( |
| 1350 | + self.generation_config._from_model_config # 1) |
| 1351 | + and self.generation_config._original_object_hash == hash(self.generation_config) # 2) |
| 1352 | + ): |
| 1353 | + new_generation_config = GenerationConfig.from_model_config(self.config) |
| 1354 | + if new_generation_config != self.generation_config: # 4) |
| 1355 | + warnings.warn( |
| 1356 | + "You have modified the pretrained model configuration to control generation. This is a" |
| 1357 | + " deprecated strategy to control generation and will be removed in v5." |
| 1358 | + " Please use and modify the model generation configuration (see" |
| 1359 | + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", |
| 1360 | + UserWarning, |
| 1361 | + ) |
| 1362 | + self.generation_config = new_generation_config |
| 1363 | + |
| 1364 | + generation_config = self.generation_config |
| 1365 | + using_model_generation_config = True |
| 1366 | + |
| 1367 | + generation_config = copy.deepcopy(generation_config) |
| 1368 | + model_kwargs = generation_config.update(**kwargs) |
| 1369 | + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model |
| 1370 | + if not using_model_generation_config: |
| 1371 | + if generation_config.bos_token_id is None: |
| 1372 | + generation_config.bos_token_id = self.generation_config.bos_token_id |
| 1373 | + if generation_config.eos_token_id is None: |
| 1374 | + generation_config.eos_token_id = self.generation_config.eos_token_id |
| 1375 | + if generation_config.pad_token_id is None: |
| 1376 | + generation_config.pad_token_id = self.generation_config.pad_token_id |
| 1377 | + if generation_config.decoder_start_token_id is None: |
| 1378 | + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id |
| 1379 | + |
| 1380 | + return generation_config, model_kwargs |
| 1381 | + |
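As a sanity check on the `num_image_token` arithmetic in `preprocess_inputs` above, a small sketch using typical InternVL2 values (image_size 448, patch_size 14, downsample_ratio 0.5 are assumed here for illustration; the real values come from `self.config`):

```python
# assumed, typical InternVL2 configuration values
image_size, patch_size, downsample_ratio = 448, 14, 0.5

# one 448x448 tile gives (448 // 14) ** 2 = 1024 vision patches; pixel-shuffle
# downsampling keeps downsample_ratio ** 2 = 0.25 of them -> 256 IMG_CONTEXT tokens per tile
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio**2)
assert num_image_token == 256

# a 2x3 tiling plus the thumbnail tile yields 7 tiles -> 7 * 256 = 1792 placeholder tokens in the prompt
num_patches = 2 * 3 + 1
total_context_tokens = num_image_token * num_patches
```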
1212 | 1382 |
|
1213 | 1383 | class _OVMiniCPMVForCausalLM(OVModelForVisualCausalLM): |
1214 | 1384 | additional_parts = ["resampler"] |
@@ -1430,14 +1600,22 @@ def merge_vision_text_embeddings( |
1430 | 1600 |
|
1431 | 1601 | @staticmethod |
1432 | 1602 | def preprocess_inputs( |
1433 | | - processor, |
1434 | 1603 | text: str, |
1435 | 1604 | image: Optional[Image] = None, |
| 1605 | + processor: Optional[AutoImageProcessor] = None, |
1436 | 1606 | tokenizer: Optional[PreTrainedTokenizer] = None, |
1437 | 1607 | ): |
1438 | | - if image is None: |
1439 | | - raise ValueError("Image is required.") |
1440 | | - prompt = f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n" |
| 1608 | + if processor is None: |
| 1609 | + raise ValueError("Processor is required.") |
| 1610 | + if getattr(processor, "chat_template", None) is not None: |
| 1611 | + messages = [{"role": "user", "content": text if image is None else "(<image>./</image>)\n" + text}] |
| 1612 | + prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
| 1613 | + else: |
| 1614 | + prompt = ( |
| 1615 | + f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n" |
| 1616 | + if image is not None |
| 1617 | + else text |
| 1618 | + ) |
1441 | 1619 | inputs = processor([prompt], [image], return_tensors="pt") |
1442 | 1620 | return inputs |
1443 | 1621 |
|
@@ -1615,17 +1793,24 @@ def get_multimodal_embeddings( |
1615 | 1793 |
|
1616 | 1794 | @staticmethod |
1617 | 1795 | def preprocess_inputs( |
1618 | | - processor, |
1619 | 1796 | text: str, |
1620 | 1797 | image: Optional[Image] = None, |
| 1798 | + processor: Optional[AutoImageProcessor] = None, |
1621 | 1799 | tokenizer: Optional[PreTrainedTokenizer] = None, |
1622 | 1800 | ): |
1623 | 1801 | if tokenizer is None: |
1624 | 1802 | raise ValueError("Tokenizer is required.") |
1625 | | - messages = [{"role": "user", "content": f"<image>\n{text}"}] |
1626 | | - text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
1627 | | - text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")] |
1628 | | - input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0) |
| 1803 | + if image is not None and processor is None: |
| 1804 | + raise ValueError("Processor is required.") |
| 1805 | + text_content = f"<image>\n{text}" if image is not None else text |
| 1806 | + messages = [{"role": "user", "content": text_content}] |
| 1807 | + if tokenizer.chat_template is not None: |
| 1808 | + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
| 1809 | + if image is not None: |
| 1810 | + text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")] |
| 1811 | + input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0) |
| 1812 | + else: |
| 1813 | + input_ids = tokenizer(text, return_tensors="pt").input_ids |
1629 | 1814 | attention_mask = torch.ones_like(input_ids, dtype=torch.int64) |
1630 | 1815 | result = {"input_ids": input_ids, "attention_mask": attention_mask} |
1631 | 1816 | if image is not None: |
|
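Putting the pieces together, a minimal end-to-end sketch (the checkpoint id, `export=True`, and the decoding step are illustrative assumptions, not taken from this change):

```python
from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "llava-hf/llava-1.5-7b-hf"  # example checkpoint; any supported multimodal model works
processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

image = Image.open("example.jpg")
inputs = model.preprocess_inputs(text="Describe the image.", image=image, processor=processor)
outputs = model.generate(**inputs, max_new_tokens=64)

# strip the prompt tokens and decode only the newly generated part
print(processor.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0])
```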