37 changes: 26 additions & 11 deletions optimum/exporters/executorch/integrations.py
@@ -22,13 +22,18 @@
 from transformers import (
     AutoConfig,
     AutoProcessor,
+    DynamicCache,
+    EncoderDecoderCache,
     PreTrainedModel,
     StaticCache,
     T5ForConditionalGeneration,
     WhisperForConditionalGeneration,
 )
 from transformers.generation.configuration_utils import GenerationConfig
-from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    sdpa_mask_without_vmap,
+)
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface

@@ -50,7 +55,10 @@ def prepare_export_inputs(self):
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
+                    {
+                        "type": "image",
+                        "url": "https://llava-vl.github.io/static/images/view.jpg",
+                    },
                 ],
             },
         ]
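
For orientation, a minimal sketch (not code from this PR) of how a conversation shaped like the one above is typically turned into export inputs through the processor's chat template; the checkpoint name is an illustrative assumption:

```python
from transformers import AutoProcessor

# Hypothetical checkpoint for illustration; the PR does not pin one here.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://llava-vl.github.io/static/images/view.jpg",
            },
        ],
    },
]

# The processor resolves the image URL and tokenizes the prompt, yielding
# tensors (input_ids, pixel_values, ...) usable as example export inputs.
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
```
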
@@ -330,7 +338,10 @@ def export(
             mutated_gm,
             args=(),
             # For the ET runner, it's important to have cache position as the 2nd arg.
-            kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+            kwargs={
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position,
+            },
             dynamic_shapes=dynamic_shapes,
             strict=True,
         )
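
The comment in this hunk carries the real constraint: `torch.export` receives `kwargs` as an ordered dict, and per the comment the ExecuTorch runner expects `cache_position` as the second input. A self-contained toy sketch (hypothetical module, not the PR's graph) showing how kwargs ordering surfaces in the exported program:

```python
import torch


class TinyDecoder(torch.nn.Module):
    # Stand-in for the real decoder: consumes embeddings plus cache positions.
    def forward(self, inputs_embeds, cache_position):
        return inputs_embeds.sum(dim=-1) + cache_position.to(inputs_embeds.dtype)


exported = torch.export.export(
    TinyDecoder(),
    args=(),
    # Dict insertion order is preserved, so cache_position
    # lands as the second runtime input.
    kwargs={
        "inputs_embeds": torch.randn(1, 4, 8),
        "cache_position": torch.arange(4),
    },
    strict=True,
)
print(exported.graph_signature.user_inputs)  # embeddings first, then positions
```
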
@@ -390,7 +401,12 @@ class CausalLMExportableModule(torch.nn.Module):
     """

     def __init__(
-        self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
+        self,
+        model,
+        max_seq_len=2048,
+        use_custom_kv_cache=False,
+        use_custom_sdpa=False,
+        disable_dynamic_shapes=False,
     ):
         super().__init__()
         self.model = model
@@ -487,7 +503,10 @@ def export(

         with torch.no_grad():
             exported_program = exportable_module.export(
-                input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
+                input_ids=input_ids,
+                cache_position=cache_position,
+                dynamic_shapes=dynamic_shapes,
+                strict=strict,
             )
             # Apply RemoveTransposes pass to remove
             # any back-to-back transpose ops that are not needed
@@ -643,18 +662,14 @@ def __init__(self, model, max_static_cache_length, batch_size):
             device="cpu",
             dtype=torch.float32,
         )
-
-        # Register cache buffers to make them exportable
-        for i in range(len(self.static_cache.key_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+        self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
         outputs = self.decoder(
             input_ids=decoder_input_ids,
             encoder_hidden_states=encoder_hidden_states,
-            past_key_values=self.static_cache,
+            past_key_values=self.cache,
             use_cache=True,
             cache_position=cache_position,
         )
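
To make the cache change concrete, here is a hedged sketch of the construction this hunk switches to: the preallocated self-attention cache is paired with a `DynamicCache` for the cross-attention entries, wrapped in an `EncoderDecoderCache`. The T5 config and sizes are illustrative, assuming a transformers version whose `StaticCache` accepts the arguments shown in the diff:

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache, StaticCache, T5Config

# Illustrative config and sizes; the wrapper uses the wrapped model's own
# config, max_static_cache_length, and batch_size.
config = T5Config()
static_cache = StaticCache(
    config=config,
    max_batch_size=1,
    max_cache_len=64,
    device="cpu",
    dtype=torch.float32,
)

# Self-attention writes into the static cache; cross-attention entries are
# computed once from the encoder output, so a DynamicCache suffices there.
cache = EncoderDecoderCache(static_cache, DynamicCache())
# `cache` is what forward() now passes as past_key_values.
```

With the combined cache passed through `past_key_values`, the decoder mutates the cache tensors directly, which is why the hunk can drop the per-layer `register_buffer` loop.
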