Support image-text-to-text task #111
base: main
Changes from 5 commits
@@ -25,6 +25,7 @@
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
@@ -238,9 +239,12 @@ def _export(
     ) -> Dict[str, "ExecuTorchModule"]:
         task = kwargs.pop("task", None)
         if task is not None:
-            logger.warning(f"task was provided and set to {task} but not used, will be ignored")
-        inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class)
-        logging.info(f"Inferred task from model class: {inferred_task}")
+            logger.warning(f"task was provided and set to {task}")
+        elif hasattr(cls, "task"):
+            task = cls.task
+        else:
+            task = TasksManager.infer_task_from_model(cls.auto_model_class)
+            logging.info(f"Inferred task from model class: {task}")

         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)

Review comment on the task handling above: seems ok, I had to do this too, but @guangy10 thoughts on this?
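For readers skimming the diff, the new resolution order is: an explicitly passed `task` is now honored (previously it was logged and discarded), then a class-level `task` attribute, then inference from the associated auto model class. A minimal sketch of that precedence follows; the helper name is illustrative only, and `TasksManager` refers to the import already used in this file:

def _resolve_task(kwargs, cls):
    # 1. An explicit task from the caller wins (it used to be ignored).
    task = kwargs.pop("task", None)
    if task is not None:
        return task
    # 2. Subclasses such as ExecuTorchModelForMultimodalCausalLM pin their task as a class attribute.
    if hasattr(cls, "task"):
        return cls.task
    # 3. Otherwise infer the task from the associated auto model class.
    return TasksManager.infer_task_from_model(cls.auto_model_class)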
@@ -249,7 +253,7 @@ def _export(
         executorch_progs = main_export(
             model_name_or_path=model_id,
             output_dir=save_dir_path,
-            task=inferred_task,
+            task=task,
             recipe=recipe,
             config=config,
             subfolder=subfolder,
@@ -309,6 +313,8 @@ def from_pretrained(
             model_dir = os.path.join(cached_model_dir, "snapshots", _revision)
         else:
             model_dir = model_id
+        if not config:
+            config = AutoConfig.from_pretrained(model_id)

         pte_files = find_files_matching_pattern(
             model_dir,
@@ -1082,3 +1088,137 @@ def transcribe(
         self.stats.on_inference_end()
         self.stats.print_report()
         return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+
+class ExecuTorchModelForMultimodalCausalLM(ExecuTorchModelBase):
+    """
+    ExecuTorch model for CausalLM with multimodal capability.
+
+    Although the auto_model_class is `AutoModelForCausalLM`, same as `ExecuTorchModelForCausalLM`, this model is specifically
+    designed for multimodal-text-to-text tasks. This class provides an interface for loading, running, and generating outputs
+    from a vision-language model or an audio-language model optimized for the ExecuTorch Runtime. It includes utilities for
+    exporting and loading pre-trained models compatible with the ExecuTorch runtime.
+
+    Attributes:
+        auto_model_class (`Type`):
+            Associated Transformers class, `AutoModelForCausalLM`.
+        model (`ExecuTorchModule`):
+            The loaded ExecuTorch model.
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    task = "multimodal-text-to-text"
+
+    def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
+        super().__init__(models, config)
+        if not hasattr(self, "model"):
+            raise AttributeError("Expected attribute 'model' not found in the instance.")
+
+        # Make sure config contains vision_config and text_config, otherwise raise an error
+        if not hasattr(config, "vision_config") or not hasattr(config, "text_config"):
+            raise ValueError(
+                "The configuration must contain 'vision_config' and 'text_config' attributes for image-text-to-text task."
+            )
+
+        metadata = self.model.method_names()
+        logging.debug(f"Load all static methods: {metadata}")
+        if "use_kv_cache" in metadata:
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
+        if "get_max_seq_len" in metadata:
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
+        if "get_max_batch_size" in metadata:
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
+        if "get_dtype" in metadata:
+            self.dtype = self.model.run_method("get_dtype")[0]
+        if "get_bos_id" in metadata:
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
+        for key in ("get_eos_id", "get_eos_ids"):
+            if key in metadata:
+                self.eos_token_ids = self.model.run_method(key)
+                break
+        if "get_vocab_size" in metadata:
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
+        if "use_sdpa_with_kv_cache" in metadata:
+            self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]
+
+    def forward(
+        self,
+        cache_position: torch.LongTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the model, which is compatible with the ExecuTorch runtime for LLM. Here we are assuming pixel_values only represent 1 image.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing current input token id to the model.
+            pixel_values (`torch.Tensor`): Tensor representing image input to the model.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        if (input_ids is None) and (pixel_values is None):
+            raise ValueError("You must specify at least one of input_ids or pixel_values")
+        self.stats.on_model_execution_start()
+
+        inputs_embeds = self.model.run_method("token_embedding", (input_ids,))[0]
+
+        if pixel_values is not None:
+            image_features = self.model.run_method("image_encoder", (pixel_values,))[0]
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.model.run_method(
+                    "token_embedding",
+                    (torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device),),
+                )[0]
+            else:
+                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

Review thread on the image-feature scattering:

@larryliu0820 so not doing this in runtime means we make assumptions on where the image tokens go in the prompt, right?

Yeah, the runner will have to take in a vector of inputs, then prefill sequentially.

I find the sequential prefilling a bit strange in the runner, mainly because it assumes a format from the chat template, namely that image tokens come last. You really need to do masked scatter, no?

The runner knows nothing about the chat template. It only sees [image, text, image, ...].

I understand that, but where image tokens go (or, in the future, where speech tokens go) is a property of the model's chat template, isn't it? So whether it is managed in the runner or the layer above doesn't matter, but it would have to be accounted for somewhere.
+
+        logits = self.model.run_method("text_model", (cache_position, inputs_embeds))[0]
+
+        self.stats.on_model_execution_end()
+        return logits
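To make the behavior discussed in the thread above concrete, here is a small, self-contained sketch (toy shapes and a made-up placeholder id, not the model's real values) of how masked_scatter drops image embeddings into the placeholder-token positions, wherever those placeholders sit in the prompt:

import torch

image_token_id = 7                                    # made-up placeholder id
input_ids = torch.tensor([[1, 7, 7, 4, 5]])           # text, image, image, text, text
inputs_embeds = torch.zeros(1, 5, 3)                  # (batch, seq, hidden) dummy token embeddings
image_features = torch.arange(6.0).reshape(2, 3)      # 2 image "patches" with hidden size 3

# True exactly at the placeholder positions, broadcast across the hidden dim.
mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, image_features)

print(merged[0, 1])  # tensor([0., 1., 2.]) -> first image patch landed at position 1
print(merged[0, 2])  # tensor([3., 4., 5.]) -> second patch landed at position 2

Because the scatter is driven by the placeholder mask, the image tokens do not have to come last in the prompt, which is the point raised about the sequential-prefill alternative.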
+
+    def generate(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        input_ids: torch.LongTensor,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        max_new_tokens: int = 100,
+    ):
+        # Sanity check
+        if max_new_tokens <= 0:
+            raise ValueError(f"max_new_tokens must be greater than 0, got {max_new_tokens}.")
+        elif max_new_tokens > self.max_cache_size:
+            logging.warning(
+                f"max_new_tokens={max_new_tokens} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size."
+            )
+            max_new_tokens = self.max_cache_size
+
+        # Prefill
+        logits = self.forward(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            cache_position=torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device),
+        )
+
+        tokens = []
+
+        token = torch.argmax(logits[:, -1, :], dim=-1).item()
+        tokens.append(token)
+        i = 1
+        while i < max_new_tokens:
+            # Generate next token
+            logits = self.forward(
+                input_ids=torch.tensor([token], dtype=torch.long, device=input_ids.device).unsqueeze(0),
+                cache_position=torch.tensor([input_ids.size(1) + i - 1], dtype=torch.long, device=input_ids.device),
+            )
+            token = torch.argmax(logits[:, -1, :], dim=-1).item()
+            tokens.append(token)
+
+            if token in self.eos_token_ids:
+                break
+            i += 1
+
+        return tokenizer.decode(tokens, skip_special_tokens=True)

Review thread:

why this change?

This is giving a weird issue in verifying the e2e workflow using ExportedProgram. I forgot what exactly, though.

If not needed, please undo.

No, we definitely need this, otherwise e2e won’t work.
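For context, a minimal usage sketch of the new class. The checkpoint name, import path, recipe value, and placeholder token below are illustrative assumptions, not taken from this PR:

from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultimodalCausalLM  # assumed export location

model_id = "org/vision-language-model"  # placeholder, not a real checkpoint
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export (or load an already exported .pte); the recipe mirrors how the existing text-only class is typically used.
model = ExecuTorchModelForMultimodalCausalLM.from_pretrained(model_id, recipe="xnnpack")

# The prompt must contain the model's image placeholder token(s); the exact template is model-specific.
image = Image.open("example.jpg")
prompt = "<image> Describe this picture."
inputs = processor(images=image, text=prompt, return_tensors="pt")

text = model.generate(
    tokenizer=tokenizer,
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=64,
)
print(text)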