4 changes: 1 addition & 3 deletions optimum/commands/export/executorch.py
@@ -185,9 +185,7 @@ def run(self):
"--qlinear_packing_format can only be used when --device is set to CUDA (e.g., 'cuda', 'cuda:0', etc.)"
)
if not self.args.qlinear or self.args.qlinear != "4w":
raise ValueError(
"--qlinear_packing_format can only be used when --qlinear is set to '4w'"
)
raise ValueError("--qlinear_packing_format can only be used when --qlinear is set to '4w'")
qlinear_encoder_packing_format = getattr(self.args, "qlinear_encoder_packing_format", None)
if qlinear_encoder_packing_format:
if not device or not device.startswith("cuda"):
46 changes: 27 additions & 19 deletions optimum/executorch/modeling.py
@@ -40,7 +40,10 @@
from transformers.processing_utils import ProcessorMixin
from transformers.utils import is_offline_mode

from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
from executorch.extension.pybindings.portable_lib import (
ExecuTorchModule,
_load_for_executorch,
)
from executorch.kernels import quantized # noqa

from ..exporters import TasksManager
@@ -460,7 +463,7 @@ def __init__(
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "text_decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
raise AttributeError("Expected attribute 'text_decoder' not found in the instance.")
metadata = self.decoder.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
@@ -495,7 +498,10 @@ def forward(
encoder_outputs = self.encoder.forward((input_ids,))[0]
self.stats.on_prompt_eval_end()

result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
result = (
self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0],
encoder_outputs,
)
self.stats.on_model_execution_end()
return result

@@ -1022,29 +1028,27 @@ def __init__(
config: "PretrainedConfig",
):
super().__init__(models=models, config=config)
if not hasattr(self, "encoder"):
raise AttributeError("Expected attribute 'encoder' not found in the instance.")
if not hasattr(self, "text_decoder"):
raise AttributeError("Expected attribute 'decoder' not found in the instance.")
metadata = self.decoder.method_names()
if not hasattr(self, "model"):
raise AttributeError("Expected attribute 'model' not found in the instance.")
metadata = self.model.method_names()
if "use_kv_cache" in metadata:
self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
if "get_max_seq_len" in metadata:
self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
if "get_max_batch_size" in metadata:
self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
if "get_dtype" in metadata:
self.dtype = self.decoder.run_method("get_dtype")[0]
self.dtype = self.model.run_method("get_dtype")[0]
if "get_bos_id" in metadata:
self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
self.bos_token_id = self.model.run_method("get_bos_id")[0]
if "get_eos_id" in metadata:
self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
self.eos_token_id = self.model.run_method("get_eos_id")[0]
if "get_vocab_size" in metadata:
self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
self.vocab_size = self.model.run_method("get_vocab_size")[0]
if "max_hidden_seq_length" in metadata:
self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
if "decoder_start_token_id" in metadata:
self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

def forward(
self,
@@ -1056,10 +1060,13 @@ def forward(
is_first_prediction = encoder_outputs is None
self.stats.on_model_execution_start()
if is_first_prediction:
encoder_outputs = self.encoder.forward((input_features,))[0]
encoder_outputs = self.model.run_method("encoder", (input_features,))[0]
self.stats.on_prompt_eval_end()

result = (self.decoder.forward((decoder_input_ids, encoder_outputs, cache_position))[0], encoder_outputs)
result = (
self.model.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0],
encoder_outputs,
)
self.stats.on_model_execution_end()
return result

@@ -1117,6 +1124,7 @@ def generate(
if not first_token_generated:
self.stats.on_first_token()
first_token_generated = True

# Get next token
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
generated_ids.append(next_token)
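Note (illustrative only, not part of this diff): with the runtime above reading everything from a single multi-method program, loading such a `.pte` through the portable pybindings and querying its metadata could look roughly like the sketch below. The file path, the Whisper-shaped input, and the exact set of method names are assumptions.

```python
# Rough sketch: assumes a multi-method .pte that exposes "encoder" and
# "text_decoder" methods plus the constant metadata methods referenced above.
# The path and the Whisper-shaped input features are placeholders.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

program = _load_for_executorch("model.pte")  # placeholder path
print(program.method_names())  # expected to include "encoder", "text_decoder", metadata methods

# Constant metadata methods return a single value.
max_seq_len = program.run_method("get_max_seq_len")[0]
decoder_start = program.run_method("decoder_start_token_id")[0]

# The encoder runs once; its output feeds every decoder step.
input_features = torch.zeros((1, 80, 3000), dtype=torch.float32)  # placeholder Whisper features
encoder_outputs = program.run_method("encoder", (input_features,))[0]

decoder_input_ids = torch.tensor([[int(decoder_start)]], dtype=torch.long)
cache_position = torch.tensor([0], dtype=torch.long)
logits = program.run_method("text_decoder", (decoder_input_ids, encoder_outputs, cache_position))[0]
```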
114 changes: 65 additions & 49 deletions optimum/exporters/executorch/integrations.py
@@ -22,13 +22,17 @@
from transformers import (
AutoConfig,
AutoProcessor,
DynamicCache,
EncoderDecoderCache,
PreTrainedModel,
StaticCache,
T5ForConditionalGeneration,
WhisperForConditionalGeneration,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM, sdpa_mask_without_vmap
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
sdpa_mask_without_vmap,
)
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

@@ -50,7 +54,10 @@ def prepare_export_inputs(self):
{
"role": "user",
"content": [
{"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
{
"type": "image",
"url": "https://llava-vl.github.io/static/images/view.jpg",
},
],
},
]
@@ -337,7 +344,10 @@ def export(
mutated_gm,
args=(),
# For the ET runner, it's important to have cache position as the 2nd arg.
kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
kwargs={
"inputs_embeds": inputs_embeds,
"cache_position": cache_position,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
@@ -400,7 +410,12 @@ class CausalLMExportableModule(torch.nn.Module):
"""

def __init__(
self, model, max_seq_len=2048, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False
self,
model,
max_seq_len=2048,
use_custom_kv_cache=False,
use_custom_sdpa=False,
disable_dynamic_shapes=False,
):
super().__init__()
self.model = model
@@ -497,7 +512,10 @@ def export(

with torch.no_grad():
exported_program = exportable_module.export(
input_ids=input_ids, cache_position=cache_position, dynamic_shapes=dynamic_shapes, strict=strict
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=strict,
)
# Apply RemoveTransposes pass to remove
# any back-to-back transpose ops that are not needed
@@ -645,26 +663,38 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.proj_out = model.lm_head
self.config = model.config

# Initialize static cache
self.static_cache = StaticCache(
# Initialize self attention cache
self.self_attention_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device="cpu",
device=model.device,
dtype=torch.float32,
)

# Register cache buffers to make them exportable
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)

# Initialize cross attention cache
self.dynamic_cache = DynamicCache(config=self.config)
self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache)

# Register cache buffers to make them exportable.
# Cross attention cache buffer is not registered since it's not actually being used atm.
for i in range(len(self.self_attention_cache)):
self.register_buffer(
f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False
)
self.register_buffer(
f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False
)

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=self.static_cache,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
@@ -679,26 +709,18 @@ def __init__(
self,
model: PreTrainedModel,
batch_size=1,
max_hidden_seq_length=4096,
cache_implementation="static",
max_cache_length=1024,
max_seq_len=1024,
max_hidden_seq_len=4096,
):
super().__init__()

self.full_model = model
self.model = model
self.encoder = model.get_encoder()
self.config = model.config
self.max_hidden_seq_length = max_hidden_seq_length
self.generation_config = GenerationConfig(
use_cache=True,
max_length=max_cache_length,
cache_implementation=cache_implementation,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_length,
},
)
if isinstance(self.full_model, WhisperForConditionalGeneration):
self.max_hidden_seq_len = max_hidden_seq_len
self.batch_size = batch_size
self.max_seq_len = max_seq_len
if isinstance(self.model, WhisperForConditionalGeneration):
self._processor = AutoProcessor.from_pretrained(model.config._name_or_path)
self._expected_encoder_input_shape = torch.Size(
(
Expand All @@ -707,33 +729,27 @@ def __init__(
self._processor.feature_extractor.nb_max_frames,
)
)
additional_configs = {}
additional_configs["max_hidden_seq_length"] = max_hidden_seq_length
# Metadata to be recorded in the pte model file
self.metadata = save_config_to_constant_methods(
self.config,
self.generation_config,
**additional_configs,
)
self.metadata = save_config_to_constant_methods(self.config, get_max_seq_len=max_seq_len)
self.exported_encoder = None
self.exported_decoder = None

def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

# Define dynamic sequence length for encoder
if isinstance(self.full_model, WhisperForConditionalGeneration):
if isinstance(self.model, WhisperForConditionalGeneration):
assert (
encoder_input_ids.shape == self._expected_encoder_input_shape
), f"""This version of Whisper only accepts encoder input of shape {self._expected_encoder_input_shape}, passed shape: {encoder_input_ids.shape}.
For more information, please refer to the Whisper preprocessor config."""
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
elif isinstance(self.model, T5ForConditionalGeneration):
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
dynamic_shapes = {"input_ids": {1: encoder_seq_len_dim}}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule encoder export."
f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule encoder export."
)

# Export the encoder
@@ -749,27 +765,27 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
wrapped_decoder = (
Seq2SeqLMDecoderExportableModuleWithStaticCache(
model=self.full_model,
max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
batch_size=self.generation_config.cache_config.get("batch_size"),
model=self.model,
max_static_cache_length=self.max_seq_len,
batch_size=self.batch_size,
)
.to("cpu")
.eval()
)

if isinstance(self.full_model, WhisperForConditionalGeneration):
if isinstance(self.model, WhisperForConditionalGeneration):
dynamic_shapes = None
elif isinstance(self.full_model, T5ForConditionalGeneration):
elif isinstance(self.model, T5ForConditionalGeneration):
# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
dynamic_shapes = {
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
}
else:
raise ValueError(
f"Unsupported model type {type(self.full_model)} for Seq2SeqLMExportableModule decoder export."
f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export."
)

# Export the decoder
Expand All @@ -791,7 +807,7 @@ def export(
cache_position=None,
) -> Dict[str, ExportedProgram]:
if encoder_input_ids is None:
if isinstance(self.full_model, WhisperForConditionalGeneration):
if isinstance(self.model, WhisperForConditionalGeneration):
example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
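Note (illustrative only, not part of this diff): a minimal sketch of driving the reworked `Seq2SeqLMExportableModule` directly, assuming it is importable from this module's path; the checkpoint is just an example and the constructor values repeat the defaults introduced above.

```python
# Minimal sketch under the assumptions stated above.
import torch
from transformers import WhisperForConditionalGeneration

from optimum.exporters.executorch.integrations import Seq2SeqLMExportableModule

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

exportable = Seq2SeqLMExportableModule(
    model,
    batch_size=1,            # mirrors the new constructor arguments
    max_seq_len=1024,        # replaces the old cache_config-driven max_cache_length
    max_hidden_seq_len=4096,
)

with torch.no_grad():
    exported = exportable.export()  # Dict[str, ExportedProgram], one entry per exported method
print(list(exported.keys()))
```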
33 changes: 15 additions & 18 deletions optimum/exporters/executorch/recipes/portable.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Union

from torch.export import ExportedProgram
@@ -58,24 +57,22 @@ def _lower_to_executorch(
exported_programs: Dict[str, ExportedProgram],
metadata=None,
) -> Dict[str, ExecutorchProgram]:
et_progs = {}
# If just one exported program, the method name in the .pte for it should be "forward".
if len(exported_programs) == 1:
exported_programs = {"forward": next(iter(exported_programs.values()))}

for pte_name, exported_program in exported_programs.items():
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
et_progs[pte_name] = to_edge_transform_and_lower(
exported_program,
partitioner=[],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=metadata,
transform_passes=[RemovePaddingIdxEmbeddingPass()],
).to_executorch()
logging.debug(
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
)
return et_progs
et_prog = to_edge_transform_and_lower(
exported_programs,
partitioner=[],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods=metadata,
transform_passes=[RemovePaddingIdxEmbeddingPass()],
).to_executorch()
pte_name = "model"
return {pte_name: et_prog}

exported_progs = model.export()

Expand Down