45 changes: 26 additions & 19 deletions optimum/executorch/modeling.py
@@ -40,7 +40,10 @@
from transformers.processing_utils import ProcessorMixin
from transformers.utils import is_offline_mode

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
from executorch.extension.pybindings.portable_lib import (
ExecuTorchModule,
_load_for_executorch,
)
from executorch.kernels import quantized # noqa

from ..exporters import TasksManager
@@ -460,7 +463,7 @@ def __init__(
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "text_decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
metadata = self.decoder.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
@@ -495,7 +498,10 @@ def forward(
encoder_outputs = self.encoder.forward((input_ids,))[0]
self.stats.on_prompt_eval_end()

result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
result = (
self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
encoder_outputs,
)
self.stats.on_model_execution_end()
return result

@@ -1022,29 +1028,27 @@ def __init__(
config: "PretrainedConfig",
):
super().__init__(models=models, config=config)
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "text_decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
metadata = self.decoder.method_names()
if not hasattr(self, "model"):
raise AttributeError("Expected attribute 'model' not found in the instance.")
metadata = self.model.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.decoder.run_method("get_dtype")[0]
self.dtype = self.model.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
self.bos_token_id = self.model.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
self.eos_token_id = self.model.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "max_hidden_seq_length" in metadata:
self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
if "decoder_start_token_id" in metadata:
self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

def forward(
self,
@@ -1056,10 +1060,13 @@ def forward
is_first_prediction = encoder_outputs is None
self.stats.on_model_execution_start()
if is_first_prediction:
encoder_outputs = self.encoder.forward((input_features,))[0]
encoder_outputs = self.model.run_method("encoder", (input_features,))[0]
self.stats.on_prompt_eval_end()

result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
result = (
self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
encoder_outputs,
)
self.stats.on_model_execution_end()
return result

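
Note on the modeling.py change above: the Whisper runtime wrapper no longer holds separate encoder and text_decoder ExecuTorch programs; it holds a single multi-method program in self.model and reaches the encoder and decoder through run_method. A minimal sketch of driving such a merged program directly, where the model.pte path, the log-mel feature shape, and the start token are assumptions for illustration only:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Hypothetical path to the merged program produced by the exporter.
program = _load_for_executorch("whisper/model.pte")

# Metadata is exposed as extra methods, read the same way __init__ does above.
metadata = program.method_names()
max_seq_len = program.run_method("get_max_seq_len")[0] if "get_max_seq_len" in metadata else None

# Encoder and decoder are two methods on the same program.
input_features = torch.zeros(1, 80, 3000)    # assumed log-mel feature shape
encoder_outputs = program.run_method("encoder", (input_features,))[0]

decoder_input_ids = torch.tensor([[50258]])  # assumed decoder start token
cache_position = torch.tensor([0])
logits = program.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0]
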
37 changes: 26 additions & 11 deletions optimum/exporters/executorch/integrations.py
@@ -22,13 +22,18 @@
from transformers import (
AutoConfig,
AutoProcessor,
DynamicCache,
EncoderDecoderCache,
PreTrainedModel,
StaticCache,
T5ForConditionalGeneration,
WhisperForConditionalGeneration,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
sdpa_mask_without_vmap,
)
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

@@ -50,7 +55,10 @@ def prepare_export_inputs(self):
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{
"type": "image",
"url": "https://llava-vl.github.io/static/images/view.jpg",
},
],
},
]
@@ -330,7 +338,10 @@
mutated_gm,
args=(),
# For the ET runner, it's important to have cache position as the 2nd arg.
kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
kwargs={
"inputs_embeds": inputs_embeds,
"cache_position": cache_position,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
@@ -390,7 +401,12 @@ class CausalLMExportableModule(torch.nn.Module):
"""

def __init__(
self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
self,
model,
max_seq_len=2048,
use_custom_kv_cache=False,
use_custom_sdpa=False,
disable_dynamic_shapes=False,
):
super().__init__()
self.model = model
@@ -487,7 +503,10 @@ def export(

with torch.no_grad():
exported_program = exportable_module.export(
input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=strict,
)
# Apply RemoveTransposes pass to remove
# any back-to-back transpose ops that are not needed
@@ -643,18 +662,14 @@ def __init__(self, model, max_static_cache_length, batch_size):
device="cpu",
dtype=torch.float32,
)

# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=self.static_cache,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
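
Note on the integrations.py change above: instead of registering the static cache's key/value buffers by hand and passing the StaticCache itself, the decoder wrapper now passes an EncoderDecoderCache that pairs the pre-allocated StaticCache (decoder self-attention) with a DynamicCache (cross-attention over the encoder states). A rough standalone sketch of the same construction, with the checkpoint, cache sizes, and tensor shapes assumed for illustration:

import torch
from transformers import (
    DynamicCache,
    EncoderDecoderCache,
    StaticCache,
    WhisperForConditionalGeneration,
)

# Assumed checkpoint and sizes; the exportable module receives these via its constructor.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
self_attention_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=128,
    device="cpu",
    dtype=torch.float32,
)
# Self-attention reuses the bounded static cache; cross-attention key/values are
# computed once from the encoder output and live in the growable dynamic cache.
cache = EncoderDecoderCache(self_attention_cache, DynamicCache())

decoder = model.get_decoder()
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
encoder_hidden_states = torch.zeros(1, 1500, model.config.d_model)  # assumed encoder output shape
outputs = decoder(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
    past_key_values=cache,
    use_cache=True,
    cache_position=torch.tensor([0]),
)
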
14 changes: 7 additions & 7 deletions tests/models/test_modeling_whisper.py
@@ -49,8 +49,7 @@ def test_whisper_export_to_executorch(self):
shell=True,
check=True,
)
self.assertTrue(os.path.exists(f"{tempdir}/executorch/encoder.pte"))
self.assertTrue(os.path.exists(f"{tempdir}/executorch/decoder.pte"))
self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
model = ExecuTorchModelForSpeechSeq2Seq.from_pretrained(f"{tempdir}/executorch")
self._test_whisper_transcription(model_id, model)

@@ -59,16 +58,17 @@ def _test_whisper_transcription(self, model_id: str, model: ExecuTorchModelForSp
processor = AutoProcessor.from_pretrained(model_id)

self.assertIsInstance(model, ExecuTorchModelForSpeechSeq2Seq)
self.assertTrue(hasattr(model, "encoder"))
self.assertIsInstance(model.encoder, ExecuTorchModule)
self.assertTrue(hasattr(model, "text_decoder"))
self.assertIsInstance(model.decoder, ExecuTorchModule)
self.assertTrue(hasattr(model, "model"))
self.assertIsInstance(model.model, ExecuTorchModule)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

input_features = processor(
sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
sample["array"],
return_tensors="pt",
truncation=False,
sampling_rate=sample["sampling_rate"],
).input_features
# The current implementation of the transcribe method accepts up to 30 seconds of audio, so the audio is trimmed here.
input_features_trimmed = input_features[:, :, :3000].contiguous()
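
For context on the 3000-frame trim above: Whisper's feature extractor produces log-mel frames at 100 per second (hop length 160 on 16 kHz audio), so 30 seconds of audio corresponds to 3000 frames. A quick sanity check of that arithmetic:

# Standard Whisper feature-extraction parameters (16 kHz audio, hop length 160).
sampling_rate = 16_000
hop_length = 160
frames_per_second = sampling_rate // hop_length  # 100 mel frames per second
max_seconds = 30
max_frames = max_seconds * frames_per_second     # 3000, matching the slice above
assert max_frames == 3000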