|
4 | 4 | import os |
5 | 5 | import types |
6 | 6 | from contextlib import contextmanager, nullcontext |
7 | | -from typing import Any, Dict, List, Optional, Union |
| 7 | +from typing import Any, Dict, List, Optional, Tuple, Union |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | import torch.nn as nn |
11 | 11 | from accelerate import init_empty_weights, load_checkpoint_in_model |
12 | 12 | from accelerate.utils import modeling |
13 | 13 | from huggingface_hub import HfApi, snapshot_download |
14 | 14 | from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id |
| 15 | +from PIL import Image |
15 | 16 | from torch._prims_common import DeviceLikeType |
16 | 17 | from transformers import ( |
17 | 18 | AutoConfig, |
18 | 19 | AutoModelForCausalLM, |
19 | 20 | AutoModelForImageTextToText, |
| 21 | + AutoProcessor, |
20 | 22 | AutoTokenizer, |
21 | 23 | PretrainedConfig, |
22 | 24 | ) |
|
27 | 29 | WEIGHTS_NAME, |
28 | 30 | ) |
29 | 31 |
|
30 | | -from ..custom_ops.attention_interface import CacheConfig |
| 32 | +from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback |
31 | 33 | from ..utils._config import deep_merge_dicts |
32 | 34 | from ..utils.logger import ad_logger |
33 | 35 | from .factory import ModelFactory, ModelFactoryRegistry |
@@ -108,10 +110,6 @@ def __init__(self, *args, **kwargs): |
108 | 110 | def autoconfig_from_pretrained(self): |
109 | 111 | return AutoConfig.from_pretrained |
110 | 112 |
|
111 | | - @property |
112 | | - def autotokenizer_from_pretrained(self): |
113 | | - return AutoTokenizer.from_pretrained |
114 | | - |
115 | 113 | # TODO (@lucaslie): Do we ever want to switch to from_pretrained? |
116 | 114 | @property |
117 | 115 | def automodel_from_config(self): |
@@ -200,7 +198,7 @@ def init_tokenizer(self) -> Optional[Any]: |
200 | 198 | """Initialize the tokenizer—either a custom name or the model's default.""" |
201 | 199 | if self.tokenizer is None: |
202 | 200 | return None |
203 | | - return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs) |
| 201 | + return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) |
204 | 202 |
|
205 | 203 | @staticmethod |
206 | 204 | def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]: |
@@ -366,3 +364,100 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]: |
@property
def automodel_from_config(self):
    """Callable used to build the model from a config object.

    The property hands back ``AutoModelForImageTextToText.from_config`` itself
    (the callable, not the result of calling it).
    """
    build_from_config = AutoModelForImageTextToText.from_config
    return build_from_config
| 367 | + |
def init_tokenizer(self) -> Optional[Any]:
    """Return the tokenizer that belongs to this model's processor.

    The processor is obtained from ``self.init_processor()``. When that yields
    ``None`` there is no tokenizer to hand out and ``None`` is returned as well;
    otherwise the processor's ``tokenizer`` attribute is returned.
    """
    proc = self.init_processor()
    return None if proc is None else proc.tokenizer
| 374 | + |
def init_processor(self) -> Optional[Any]:
    """Load the processor identified by ``self.tokenizer``.

    Returns:
        ``None`` when ``self.tokenizer`` is ``None``; otherwise whatever
        ``AutoProcessor.from_pretrained`` returns for ``self.tokenizer`` with
        ``self.tokenizer_kwargs`` forwarded as keyword arguments.
    """
    processor_source = self.tokenizer
    if processor_source is not None:
        return AutoProcessor.from_pretrained(processor_source, **self.tokenizer_kwargs)
    return None
| 380 | + |
@staticmethod
def _simple_forward(
    model: nn.Module,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    pixel_values: torch.Tensor,
):
    """Run the model's forward with explicitly named arguments.

    The signature follows the standard form expected by factory.py. The
    class-level ``forward`` is looked up via ``type(model)`` and invoked with
    ``model`` as its first argument, i.e. the call does not go through
    ``model(...)``.
    """
    forward_kwargs = {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "pixel_values": pixel_values,
    }
    unbound_forward = type(model).forward
    return unbound_forward(model, **forward_kwargs)
| 398 | + |
def get_example_inputs(self) -> Dict[str, torch.Tensor]:
    """Return a dictionary of example inputs for the model.

    Builds a batch of two single-turn user conversations, each holding two small
    solid-color images followed by a text prompt, and runs the batch through the
    processor's chat template.

    Returns:
        A dictionary with the ``"input_ids"`` and ``"pixel_values"`` entries taken
        from the processor output.

    Raises:
        ValueError: If ``self.init_processor()`` returns ``None``, i.e. no processor
            is available to tokenize the example conversations.
    """
    # Go through init_processor() so the processor is created in exactly one place,
    # and fail early with a clear message instead of deep inside the processor loader.
    processor = self.init_processor()
    if processor is None:
        raise ValueError(
            "Cannot build example inputs: no processor available (is `tokenizer` set?)."
        )

    def _prep_seq(text, img1, img2):
        # One conversation: a single user turn with two images and a text prompt.
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img1},
                    {"type": "image", "image": img2},
                    {"type": "text", "text": text},
                ],
            }
        ]

    # Create a batch of conversations (batch_size = 2)
    batch_messages = [
        _prep_seq(
            "Describe what you see in the two images and their differences.",
            Image.new("RGB", (16, 16), color=(128, 128, 128)),
            Image.new("RGB", (16, 16), color=(64, 64, 64)),
        ),
        _prep_seq(
            "What are the main differences between these two images?",
            Image.new("RGB", (16, 16), color=(255, 0, 0)),
            Image.new("RGB", (16, 16), color=(0, 255, 0)),
        ),
    ]

    inputs = processor.apply_chat_template(
        batch_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
        return_attention_mask=False,
    )

    return {
        "input_ids": inputs["input_ids"],
        "pixel_values": inputs["pixel_values"],
    }
| 443 | + |
def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
    """Describe the additional inputs the model accepts.

    Returns:
        A mapping from argument name to a ``(example_input, dynamic_shape_callback)``
        pair. Calling the callback yields the dynamic-shape spec for that input.
        The only entry is ``"pixel_values"``, paired with a zero-sized placeholder
        tensor of shape ``(0, 3, 336, 336)``.
    """

    def _pixel_values_dynamic_shape():
        # TODO (lucaslie): how to set default values for dynamic shapes?
        # Dimensions 0, 2 and 3 are declared dynamic; dimension 1 is not listed.
        dynamic_dims = {}
        dynamic_dims[0] = Dim("img_batch_size", max=10)
        dynamic_dims[2] = Dim("img_height", min=32, max=2048)
        dynamic_dims[3] = Dim("img_width", min=32, max=2048)
        return dynamic_dims

    placeholder_pixel_values = torch.zeros(0, 3, 336, 336)
    extra_inputs = {"pixel_values": (placeholder_pixel_values, _pixel_values_dynamic_shape)}
    return extra_inputs
0 commit comments