3 changes: 3 additions & 0 deletions optimum/executorch/__init__.py
@@ -24,6 +24,7 @@
"ExecuTorchModelForMaskedLM",
"ExecuTorchModelForSeq2SeqLM",
"ExecuTorchModelForSpeechSeq2Seq",
"ExecuTorchModelForMultiModalToText",
],
}

@@ -34,6 +35,8 @@
ExecuTorchModelForMaskedLM,
ExecuTorchModelForSeq2SeqLM,
ExecuTorchModelForSpeechSeq2Seq,
ExecuTorchModelForImageTextToTextCausalLM,
ExecuTorchModelForMultiModalToText,
)
else:
import sys
28 changes: 18 additions & 10 deletions optimum/executorch/attentions/custom_sdpa.py
@@ -45,7 +45,7 @@ def custom_sdpa_with_start_pos_forward(

# Ignore the causal flag from kwargs but use the one in module
kwargs.pop("is_causal", None)
assert module.is_causal, "Current variant supports only causal attention"
# assert module.is_causal, "Current variant supports only causal attention"
Collaborator: I don't know if this supports non causal case

is_causal = module.is_causal
if kwargs.get("is_sliding", False):
@@ -56,13 +56,16 @@ def custom_sdpa_with_start_pos_forward(
start_pos = 0
else:
attn_mask = None
# Calculate the input pos from attention mask.
# Branch out for float vs bool mask
# assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
attention_mask = attention_mask.reshape(-1, max_seq_len)
first_row_mask = attention_mask[0, :]
# [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
if is_causal:
# Calculate the input pos from attention mask.
# Branch out for float vs bool mask
# assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
first_row_mask = attention_mask[0, :]
# [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
else:
start_pos = 0

output = torch.ops.llama.custom_sdpa(
query,
@@ -81,14 +84,19 @@ def get_custom_sdpa_for_ring_kv_cache(
exportable_module: torch.nn.Module,
) -> Callable:
# lazy importing to avoid version dependent class definition
    try:
        from executorch import version
    except ImportError:
        # Fall back gracefully when version information is unavailable.
        version = None

try:
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
CustomRingKVCache,
)
except ImportError:
raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")
# raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")
print()

def _custom_sdpa_for_ring_kv_cache(
module: torch.nn.Module,
303 changes: 299 additions & 4 deletions optimum/executorch/modeling.py
@@ -30,10 +30,11 @@
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
PretrainedConfig,
AutoModelForMultimodalTextToText,
PreTrainedTokenizer,
add_start_docstrings,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_offline_mode

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
@@ -237,9 +238,9 @@ def _export(
**kwargs,
) -> Dict[str, "ExecuTorchModule"]:
task = kwargs.pop("task", None)
if task is not None:
logger.warning(f"task was provided and set to {task} but not used, will be ignored")
inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class)
        inferred_task = task if task else TasksManager.infer_task_from_model(cls.auto_model_class)
logging.info(f"Inferred task from model class: {inferred_task}")

save_dir = TemporaryDirectory()
@@ -1098,3 +1099,297 @@ def transcribe(
self.stats.on_inference_end()
self.stats.print_report()
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)


class ExecuTorchModelForImageTextToTextCausalLM(ExecuTorchModelBase):
"""
ExecuTorch model with an image-text-to-text causal language modeling head for inference using the ExecuTorch Runtime.

    Although its `auto_model_class` is `AutoModelForCausalLM`, the same as for `ExecuTorchModelForCausalLM`, this model is
    specifically designed for image-text-to-text tasks. The class provides an interface for loading, running, and generating
    outputs from a vision-language model optimized for the ExecuTorch Runtime, and includes utilities for exporting and
    loading pre-trained models compatible with that runtime.

Attributes:
auto_model_class (`Type`):
Associated Transformers class, `AutoModelForCausalLM`.
model (`ExecuTorchModule`):
The loaded ExecuTorch model.
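
    Example (illustrative sketch only; the checkpoint id and export recipe below are assumptions, not part of this PR):

        >>> from optimum.executorch import ExecuTorchModelForImageTextToTextCausalLM
        >>> # Hypothetical vision-language checkpoint exported with the XNNPACK recipe.
        >>> model = ExecuTorchModelForImageTextToTextCausalLM.from_pretrained(
        ...     "google/gemma-3-4b-it", recipe="xnnpack"
        ... )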
"""

auto_model_class = AutoModelForCausalLM

def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
super().__init__(models, config)
if not hasattr(self, "model"):
raise AttributeError("Expected attribute 'model' not found in the instance.")

# Make sure config contains vision_config and text_config, otherwise raise an error
if not hasattr(config, "vision_config") or not hasattr(config, "text_config"):
raise ValueError(
"The configuration must contain 'vision_config' and 'text_config' attributes for image-text-to-text task."
)
metadata = self.model.method_names()
logging.debug(f"Load all static methods: {metadata}")
if "use_kv_cache" in metadata:
self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.model.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.model.run_method("get_bos_id")[0]
for key in ("get_eos_id", "get_eos_ids"):
if key in metadata:
self.eos_token_ids = self.model.run_method(key)
break
if "get_vocab_size" in metadata:
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "use_sdpa_with_kv_cache" in metadata:
self.use_sdpa_with_kv_cache = self.model.run_method("use_sdpa_with_kv_cache")[0]

def forward(
self,
input_ids: Optional[torch.LongTensor],
pixel_values: Optional[torch.FloatTensor],
inputs_embeds: Optional[torch.FloatTensor],
cache_position: torch.LongTensor,
) -> torch.Tensor:
"""
        Forward pass of the model, compatible with the ExecuTorch runtime for LLMs. `pixel_values` is assumed to represent a single image.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the model.
pixel_values (`torch.Tensor`): Tensor representing image input to the model.
inputs_embeds (`torch.Tensor`): Tensor representing input embeddings to the model.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
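
        Example (illustrative sketch; `model`, `input_ids`, `pixel_values`, and `next_token_ids` are assumed to be prepared by the caller):

            >>> # Prefill: full prompt, single image, cache positions 0..seq_len-1.
            >>> logits = model.forward(input_ids, pixel_values, None, torch.arange(input_ids.shape[-1]))
            >>> # Decode: one new token at a time, with its position in the KV cache.
            >>> logits = model.forward(next_token_ids, None, None, torch.tensor([input_ids.shape[-1]]))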
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
self.stats.on_model_execution_start()

        if inputs_embeds is None:
            inputs_embeds = self.model.run_method("text_embeddings", (input_ids,))[0]

        if pixel_values is not None:
            image_features = self.model.run_method("vision_embeddings", (pixel_values,))[0]

            if input_ids is None:
                special_image_mask = inputs_embeds == self.model.run_method(
                    "text_embeddings",
                    (torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device),),
                )[0]
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            # Scatter the image features into the image-token positions of the text embeddings.
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        logits = self.model.run_method("decoder", (inputs_embeds, cache_position))[0]
self.stats.on_model_execution_end()
return logits

def generate(
self,
tokenizer: "PretrainedTokenizer",
input_ids: torch.LongTensor,
pixel_values: Optional[torch.FloatTensor] = None,
max_new_tokens: int = 100,
):
        # TODO: implement the prefill + decode loop for image-text inputs.
        raise NotImplementedError(
            "generate() is not implemented yet for ExecuTorchModelForImageTextToTextCausalLM."
        )


class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
"""
    An ExecuTorch model for multimodal-input-to-text inference using the ExecuTorch Runtime.

Attributes:
auto_model_class (`Type`):
            Associated Transformers class, `AutoModelForMultimodalTextToText`.
model (`ExecuTorchModule`):
The loaded ExecuTorch model.
use_kv_cache (`bool`):
Whether key-value caching is enabled. For performance reasons, the exported model is
optimized to use a static cache.
max_cache_size (`int`):
Maximum sequence length supported by the cache.
max_batch_size (`int`):
Maximum supported batch size.
dtype (`str`):
Data type of the model parameters.
bos_token_id (`int`):
Beginning-of-sequence token ID.
        eos_token_ids (`List[int]`):
            End-of-sequence token IDs.
vocab_size (`int`):
Size of the model vocabulary.
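
    Example (illustrative sketch; `model`, `processor`, `tokenizer`, and `conversation` are assumed to be prepared by the caller):

        >>> # `conversation` is a chat-template message list; see `text_generation` for the expected format.
        >>> text = model.text_generation(processor, tokenizer, conversation, max_seq_len=128)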
"""

auto_model_class = AutoModelForMultimodalTextToText
# auto_model_class = AutoModel

def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
super().__init__(models=models, config=config)
if not hasattr(self, "decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
if not hasattr(self, "token_embeddings"):
raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.")
if not hasattr(self, "audio_encoder"):
raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.")
metadata = self.decoder.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.decoder.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
if "max_hidden_seq_length" in metadata:
self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
if "decoder_start_token_id" in metadata:
self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]

def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
input_features: Optional[torch.Tensor] = None,
):
token_embeddings = self.token_embeddings.forward(input_ids)
        if input_features is not None:
token_embeddings = self.audio_encoder.forward(
input_features,
token_embeddings,
input_ids,
)
output = self.decoder.forward(
token_embeddings,
cache_position,
)
return output

def generate(
self,
        prompt_tokens: List[int],
echo: bool = False,
pos_base: int = 0,
max_seq_len: Optional[int] = None,
input_features: Optional[torch.Tensor] = None,
) -> List[int]:
self.device = torch.device("cpu")
if max_seq_len is None:
# Default to max_cache_size if max_seq_len is not specified
max_seq_len = self.max_cache_size
elif max_seq_len > self.max_cache_size:
logging.warning(
f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size."
)
max_seq_len = self.max_cache_size

self.stats.on_sampling_begin()
logits = self.forward(
input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
input_features=input_features,
)
self.stats.on_sampling_end()
next_token = torch.argmax(logits, dim=-1)[0, -1].item()
self.stats.on_prompt_eval_end()
first_token_generated = False

generated_tokens = prompt_tokens + [next_token]

while len(generated_tokens) < max_seq_len:
self.stats.on_sampling_begin()
logits = self.forward(
input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
cache_position=torch.tensor(
[pos_base + len(generated_tokens) - 1],
dtype=torch.long,
device=self.device,
),
)
self.stats.on_sampling_end()
if not first_token_generated:
self.stats.on_first_token()
first_token_generated = True

next_token = torch.argmax(logits, dim=-1).item()
generated_tokens.append(next_token)

if next_token in self.eos_token_ids:
break

self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))

return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]

def text_generation(
self,
processor: "ProcessorMixin",
tokenizer: "PreTrainedTokenizer",
input_conversation: List[Dict],
echo: bool = True,
max_seq_len: Optional[int] = None,
):
"""
        Perform text generation for a given multimodal conversation using the ExecuTorch model.

        Args:
            processor (`ProcessorMixin`):
                The processor used to apply the chat template to the input conversation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer used to encode and decode the prompt and output.
            input_conversation (`List[Dict]`):
                The input conversation (chat messages) to complete.
            echo (`bool`, *optional*):
                Whether to include prompt tokens in the generated output. Defaults to `True`.
            max_seq_len (`int`, *optional*):
                Maximum sequence length for the generated output.
                Defaults to None and uses the model's `max_cache_size` attribute.
                Will be truncated to the maximum cache size if larger than `max_cache_size`.
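
        Example of the expected conversation format (illustrative; the exact content types depend on the processor's chat template):

            >>> input_conversation = [
            ...     {
            ...         "role": "user",
            ...         "content": [
            ...             {"type": "audio", "path": "sample.wav"},
            ...             {"type": "text", "text": "What is said in this clip?"},
            ...         ],
            ...     }
            ... ]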
"""
self.tokenizer = tokenizer

# Sanity check
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
raise ValueError(
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
)
if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
raise ValueError(
f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
)

# Reset stats for a new generation
self.stats.reset()
self.stats.on_inference_start()

        # Tokenize the conversation and extract audio features via the processor.
        inputs = processor.apply_chat_template(
            input_conversation,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        prompt_tokens = inputs["input_ids"][0].tolist()
self.stats.on_token_encode_end()
self.stats.set_num_prompt_tokens(len(prompt_tokens))

generated_tokens = self.generate(
prompt_tokens=prompt_tokens,
input_features=inputs["input_features"],
echo=echo,
max_seq_len=max_seq_len,
)

self.stats.on_inference_end()
self.stats.print_report()

return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)