diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 4f799f8..7f23a31 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -424,7 +424,10 @@ def __init__(
         self.use_custom_sdpa = use_custom_sdpa
         self.disable_dynamic_shapes = disable_dynamic_shapes
         self.metadata = save_config_to_constant_methods(
-            model.config, model.generation_config, get_max_seq_len=max_seq_len, enable_dynamic_shape=not self.disable_dynamic_shapes
+            model.config,
+            model.generation_config,
+            get_max_seq_len=max_seq_len,
+            enable_dynamic_shape=not self.disable_dynamic_shapes,
         )
         logging.info(f"Metadata to be recorded in PTE: {self.metadata}")
@@ -669,11 +672,11 @@ def __init__(self, model, max_static_cache_length, batch_size):
             max_batch_size=batch_size,
             max_cache_len=max_static_cache_length,
             device=model.device,
-            dtype=torch.float32,
+            dtype=model.dtype,
         )
         head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
         num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
-        self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model.device)
+        self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, model.dtype, model.device)

         # Initialize cross attention cache
         self.dynamic_cache = DynamicCache(config=self.config)
@@ -735,7 +738,7 @@ def __init__(
         self.exported_decoder = None

     def _export_encoder(self, encoder_input_ids):
-        wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()
+        wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.model.device).eval()

         # Define dynamic sequence length for encoder
         if isinstance(self.model, WhisperForConditionalGeneration):
@@ -769,7 +772,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position
                 max_static_cache_length=self.max_seq_len,
                 batch_size=self.batch_size,
             )
-            .to("cpu")
+            .to(self.model.device)
             .eval()
         )
@@ -808,9 +811,11 @@ def export(
     ) -> Dict[str, ExportedProgram]:
         if encoder_input_ids is None:
             if isinstance(self.model, WhisperForConditionalGeneration):
-                example_encoder_input_ids = torch.rand(self._expected_encoder_input_shape)
+                example_encoder_input_ids = torch.rand(
+                    self._expected_encoder_input_shape, device=self.model.device, dtype=self.model.dtype
+                )
             else:
-                example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long)
+                example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long, device=self.model.device)
         else:
             example_encoder_input_ids = encoder_input_ids
@@ -822,9 +827,15 @@ def export(
             example_encoder_hidden_states = encoder_hidden_states

         example_decoder_input_ids = (
-            decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
+            decoder_input_ids
+            if decoder_input_ids is not None
+            else torch.tensor([[0]], dtype=torch.long, device=self.model.device)
+        )
+        example_cache_position = (
+            cache_position
+            if cache_position is not None
+            else torch.tensor([0], dtype=torch.long, device=self.model.device)
         )
-        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

         self.exported_decoder = self._export_decoder(
             example_decoder_input_ids,
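The integrations.py hunks above stop hard-coding float32 and "cpu" and instead follow the wrapped model's dtype and device wherever the wrapper allocates caches or example inputs. A minimal sketch of the intended invariant, not part of the diff (the bfloat16/CPU setup and the (1, 128, 3000) mel shape are illustrative assumptions): feature-style tensors track the model, token-id tensors stay integer.

import torch

# Assumed setup: a bfloat16 model living on CPU.
model_dtype, model_device = torch.bfloat16, torch.device("cpu")

# Feature-style encoder inputs follow the model's dtype and device
# (the shape is an assumed Whisper mel-spectrogram shape).
example_encoder_features = torch.rand((1, 128, 3000), dtype=model_dtype, device=model_device)
# Token ids and cache positions stay torch.long on the same device.
example_decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device)
example_cache_position = torch.tensor([0], dtype=torch.long, device=model_device)

assert example_encoder_features.dtype == model_dtype
assert example_decoder_input_ids.dtype == torch.long and example_cache_position.dtype == torch.long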
diff --git a/optimum/exporters/executorch/tasks/asr.py b/optimum/exporters/executorch/tasks/asr.py
index ccf1a7a..fdf1298 100644
--- a/optimum/exporters/executorch/tasks/asr.py
+++ b/optimum/exporters/executorch/tasks/asr.py
@@ -44,12 +44,13 @@ def load_seq2seq_speech_model(model_name_or_path: str, **kwargs) -> Seq2SeqLMExp
         Seq2SeqLMExportableModule:
             An instance of `Seq2SeqLMExportableModule` for exporting and lowering to ExecuTorch.
     """
-    device = "cpu"
+    device = kwargs.get("device", "cpu")
     batch_size = 1
     max_hidden_seq_len = kwargs.get("max_hidden_seq_len", 4096)
     max_seq_len = kwargs.get("max_seq_len", 1024)
+    dtype = kwargs.get("dtype", "float32")

-    full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path).to(device).eval()
+    full_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, dtype=dtype, device_map=device).eval()
     return Seq2SeqLMExportableModule(
         full_model,
         batch_size=batch_size,
diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py
index 15d9e13..9f3b38e 100644
--- a/optimum/exporters/executorch/tasks/causal_lm.py
+++ b/optimum/exporters/executorch/tasks/causal_lm.py
@@ -85,7 +85,7 @@ def _load_eager_pretrained(
     eager_model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         device_map=device,
-        torch_dtype=dtype,
+        dtype=dtype,
         config=config,
         attn_implementation=attn_implementation,
         generation_config=GenerationConfig(
@@ -125,6 +125,8 @@ def _load_eager_pretrained(
                 batch_size,
                 max_length,
             )
+        else:
+            raise e

     # Must disable gradient when exporting a model with a prequantized checkpoint,
     # e.g. "pytorch/Phi-4-mini-instruct-8da4w".
diff --git a/tests/models/test_modeling_whisper.py b/tests/models/test_modeling_whisper.py
index a784a2c..36b7d90 100644
--- a/tests/models/test_modeling_whisper.py
+++ b/tests/models/test_modeling_whisper.py
@@ -99,3 +99,36 @@ def test_whisper_transcription(self):
     )
     def test_whisper_transcription_portable(self):
         self._helper_whisper_transcription(recipe="portable")
+
+    @slow
+    @pytest.mark.run_slow
+    def test_whisper_large_v3_turbo_export_bfloat16(self):
+        """Test exporting whisper-large-v3-turbo with bfloat16 and check file size is ~1.6GB"""
+        model_id = "openai/whisper-large-v3-turbo"
+        task = "automatic-speech-recognition"
+        recipe = "xnnpack"
+        dtype = "bfloat16"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch --dtype {dtype}",
+                shell=True,
+                check=True,
+            )
+
+            # Check that model.pte exists
+            model_path = os.path.join(tempdir, "executorch", "model.pte")
+            self.assertTrue(os.path.exists(model_path), f"model.pte not found at {model_path}")
+
+            # Check file size is approximately 1.6GB (allow 10% tolerance)
+            file_size_bytes = os.path.getsize(model_path)
+            file_size_gb = file_size_bytes / (1024**3)
+            expected_size_gb = 1.6
+            tolerance = 0.1  # 10% tolerance
+
+            logging.info(f"model.pte size: {file_size_gb:.2f} GB")
+            self.assertAlmostEqual(
+                file_size_gb,
+                expected_size_gb,
+                delta=expected_size_gb * tolerance,
+                msg=f"Expected file size ~{expected_size_gb}GB, but got {file_size_gb:.2f}GB",
+            )
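For context, a hedged sketch of the Python-side path that the new CLI test exercises: load_seq2seq_speech_model and its device/dtype kwargs come from the asr.py change above, while calling export() with no arguments and inspecting the returned programs is an assumed usage pattern, not something this diff adds.

from optimum.exporters.executorch.tasks.asr import load_seq2seq_speech_model

# Same checkpoint as the test, loaded in bfloat16; "dtype" and "device" are the
# kwargs read by load_seq2seq_speech_model in the asr.py change above.
module = load_seq2seq_speech_model(
    "openai/whisper-large-v3-turbo",
    dtype="bfloat16",
    device="cpu",
)

# export() returns a Dict[str, ExportedProgram] (encoder/decoder programs);
# lowering to model.pte is handled by the recipe (xnnpack in the test) on the CLI path.
exported_programs = module.export()
print(list(exported_programs))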