optimum/exporters/executorch/integrations.py (29 changes: 20 additions & 9 deletions)
@@ -424,7 +424,10 @@ def __init__(
self.use_custom_sdpa = use_custom_sdpa
self.disable_dynamic_shapes = disable_dynamic_shapes
self.metadata = save_config_to_constant_methods(
model.config, model.generation_config, get_max_seq_len=max_seq_len, enable_dynamic_shape=not self.disable_dynamic_shapes
model.config,
model.generation_config,
get_max_seq_len=max_seq_len,
enable_dynamic_shape=not self.disable_dynamic_shapes,
)
logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

@@ -669,11 +672,11 @@ def __init__(self, model, max_static_cache_length, batch_size):
max_batch_size=batch_size,
max_cache_len=max_static_cache_length,
device=model.device,
dtype=torch.float32,
dtype=model.dtype,
)
head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)
self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, model.dtype, model.device)

# Initialize cross attention cache
self.dynamic_cache = DynamicCache(config=self.config)
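
Note on the cache change above: the static and early-initialized self-attention caches now inherit whatever precision the checkpoint was loaded in, instead of always allocating float32 buffers. A minimal sketch of the intent, assuming a transformers version whose from_pretrained accepts dtype= (which the asr.py change below also relies on); the model id and dtype here are only illustrative:

    import torch
    from transformers import AutoModelForSpeechSeq2Seq

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-large-v3-turbo", dtype=torch.bfloat16
    )
    # model.dtype is now torch.bfloat16, so the StaticCache and the
    # early-initialized self-attention cache above are allocated in bf16
    # rather than hard-coded float32.
    assert model.dtype == torch.bfloat16
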
@@ -735,7 +738,7 @@ def __init__(
self.exported_decoder = None

def _export_encoder(self, encoder_input_ids):
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.model.device).eval()

# Define dynamic sequence length for encoder
if isinstance(self.model, WhisperForConditionalGeneration):
@@ -769,7 +772,7 @@ def _export_decoder(decoder_input_ids, encoder_hidden_states, cache_positi
max_static_cache_length=self.max_seq_len,
batch_size=self.batch_size,
)
.to("cpu")
.to(self.model.device)
.eval()
)

@@ -808,9 +811,11 @@ def export(
) -> Dict[str, ExportedProgram]:
if encoder_input_ids is None:
if isinstance(self.model, WhisperForConditionalGeneration):
example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
example_encoder_input_ids = torch.rand(
self._expected_encoder_input_shape, device=self.model.device, dtype=self.model.dtype
)
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long, device=self.model.device)
else:
example_encoder_input_ids = encoder_input_ids

Expand All @@ -822,9 +827,15 @@ def export(
example_encoder_hidden_states = encoder_hidden_states

example_decoder_input_ids = (
decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
decoder_input_ids
if decoder_input_ids is not None
else torch.tensor([[0]], dtype=torch.long, device=self.model.device)
)
example_cache_position = (
cache_position
if cache_position is not None
else torch.tensor([0], dtype=torch.long, device=self.model.device)
)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

self.exported_decoder = self._export_decoder(
example_decoder_input_ids,
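The remaining integrations.py hunks apply the same idea to device placement: the wrapped encoder/decoder modules and the example export inputs follow self.model.device (and, for Whisper features, self.model.dtype) instead of being pinned to CPU float32. A rough, self-contained sketch of how the example inputs are now built; the shapes and the hard-coded device/dtype are only illustrative stand-ins for the loaded model's attributes:

    import torch

    # device and dtype would normally come from the loaded model; fixed here
    # only so the snippet stands alone.
    device, dtype = torch.device("cpu"), torch.bfloat16

    encoder_input_ids = torch.ones((1, 10), dtype=torch.long, device=device)   # text-encoder path
    encoder_features = torch.rand((1, 128, 3000), device=device, dtype=dtype)  # Whisper feature path; shape illustrative
    decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=device)
    cache_position = torch.tensor([0], dtype=torch.long, device=device)
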
optimum/exporters/executorch/tasks/asr.py (5 changes: 3 additions & 2 deletions)
@@ -44,12 +44,13 @@ def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExp
Seq2SeqLMExportableModule:
An instance of `Seq2SeqLMExportableModule` for exporting and lowering to ExecuTorch.
"""
device = "cpu"
device = kwargs.get("device", "cpu")
batch_size = 1
max_hidden_seq_len = kwargs.get("max_hidden_seq_len", 4096)
max_seq_len = kwargs.get("max_seq_len", 1024)
dtype = kwargs.get("dtype", "float32")

full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval()
full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, dtype=dtype, device_map=device).eval()
return Seq2SeqLMExportableModule(
full_model,
batch_size=batch_size,
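With asr.py now reading device and dtype from kwargs, the task loader can be driven directly by export options. A hedged usage sketch; the argument values are illustrative and the import path simply mirrors the file path shown above:

    from optimum.exporters.executorch.tasks.asr import load_seq2seq_speech_model

    module = load_seq2seq_speech_model(
        "openai/whisper-large-v3-turbo",
        device="cpu",        # default when the kwarg is absent
        dtype="bfloat16",    # forwarded to from_pretrained(..., dtype=...)
        max_seq_len=1024,    # loader default when not provided
    )
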
optimum/exporters/executorch/tasks/causal_lm.py (4 changes: 3 additions & 1 deletion)
@@ -85,7 +85,7 @@ def _load_eager_pretrained(
eager_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=device,
torch_dtype=dtype,
dtype=dtype,
config=config,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
@@ -125,6 +125,8 @@ def _load_eager_pretrained(
batch_size,
max_length,
)
else:
raise e

# Must disable gradient when exporting a model with a prequantized checkpoint,
# e.g. "pytorch/Phi-4-mini-instruct-8da4w".
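The added else: raise e tightens the fallback path in causal_lm.py: only the specific failure the except block knows how to recover from is handled, and anything else now propagates instead of being silently swallowed. The surrounding try/except is not visible in this hunk, so the sketch below only illustrates the pattern with hypothetical names:

    def load_with_fallback(load_primary, load_fallback, is_recoverable):
        """Illustrative pattern only: recover from the known failure, re-raise the rest."""
        try:
            return load_primary()
        except Exception as e:
            if is_recoverable(e):
                return load_fallback()
            raise e  # the behavior this hunk adds: unexpected errors are no longer masked
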
tests/models/test_modeling_whisper.py (33 changes: 33 additions & 0 deletions)
@@ -99,3 +99,36 @@ def test_whisper_transcription(self):
)
def test_whisper_transcription_portable(self):
self._helper_whisper_transcription(recipe="portable")

@slow
@pytest.mark.run_slow
def test_whisper_large_v3_turbo_export_bfloat16(self):
"""Test exporting whisper-large-v3-turbo with bfloat16 and check file size is ~1.6GB"""
model_id = "openai/whisper-large-v3-turbo"
task = "automatic-speech-recognition"
recipe = "xnnpack"
dtype = "bfloat16"
with tempfile.TemporaryDirectory() as tempdir:
subprocess.run(
f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch --dtype {dtype}",
shell=True,
check=True,
)

# Check that model.pte exists
model_path = os.path.join(tempdir, "executorch", "model.pte")
self.assertTrue(os.path.exists(model_path), f"model.pte not found at {model_path}")

# Check file size is approximately 1.6GB (allow 10% tolerance)
file_size_bytes = os.path.getsize(model_path)
file_size_gb = file_size_bytes / (1024**3)
expected_size_gb = 1.6
tolerance = 0.1 # 10% tolerance

logging.info(f"model.pte size: {file_size_gb:.2f} GB")
self.assertAlmostEqual(
file_size_gb,
expected_size_gb,
delta=expected_size_gb * tolerance,
msg=f"Expected file size ~{expected_size_gb}GB, but got {file_size_gb:.2f}GB",
)
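
For context on the 1.6 GB expectation: whisper-large-v3-turbo has roughly 0.8B parameters, and bfloat16 stores 2 bytes per weight, so the serialized program should land near 1.6e9 bytes, about 1.5 GiB under the 1024**3 division used in the test, which sits inside the 10% tolerance. Back-of-envelope sketch (the parameter count is approximate):

    params = 0.809e9               # approximate parameter count for whisper-large-v3-turbo
    size_bytes = params * 2        # bfloat16 stores 2 bytes per weight
    print(size_bytes / (1024**3))  # ~1.51, inside the test's 1.6 +/- 0.16 window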