diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 12398da..e877114 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -185,9 +185,7 @@ def run(self): "--qlinear_packing_format can only be used when --device is set to CUDA (e.g., 'cuda', 'cuda:0', etc.)" ) if not self.args.qlinear or self.args.qlinear != "4w": - raise ValueError( - "--qlinear_packing_format can only be used when --qlinear is set to '4w'" - ) + raise ValueError("--qlinear_packing_format can only be used when --qlinear is set to '4w'") qlinear_encoder_packing_format = getattr(self.args, "qlinear_encoder_packing_format", None) if qlinear_encoder_packing_format: if not device or not device.startswith("cuda"): diff --git a/optimum/exporters/executorch/recipes/cuda.py b/optimum/exporters/executorch/recipes/cuda.py index 5aa6973..812e0b6 100644 --- a/optimum/exporters/executorch/recipes/cuda.py +++ b/optimum/exporters/executorch/recipes/cuda.py @@ -64,8 +64,6 @@ def export_to_executorch_with_cuda( For encoder-decoder models or multimodal models, it may generate multiple programs. """ # Import here to avoid version conflicts. - from torch._inductor.decomposition import conv1d_to_conv2d - from executorch.backends.cuda.cuda_backend import CudaBackend from executorch.backends.cuda.cuda_partitioner import CudaPartitioner @@ -84,13 +82,7 @@ def _lower_to_executorch( key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])] for key in exported_programs.keys() } - # Add decompositions for triton to generate kernels. - for key, ep in exported_programs.items(): - exported_programs[key] = ep.run_decompositions( - { - aten.conv1d.default: conv1d_to_conv2d, - } - ) + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]): et_prog = to_edge_transform_and_lower( exported_programs, diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py index ff507bb..f6800f8 100644 --- a/tests/models/test_modeling_gemma3.py +++ b/tests/models/test_modeling_gemma3.py @@ -22,11 +22,15 @@ import unittest import pytest +import torch from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoProcessor, AutoTokenizer from transformers.testing_utils import slow -from optimum.executorch import ExecuTorchModelForCausalLM, ExecuTorchModelForMultiModalToText +from optimum.executorch import ( + ExecuTorchModelForCausalLM, + ExecuTorchModelForMultiModalToText, +) from ..utils import check_causal_lm_output_quality, check_multimodal_output_quality @@ -288,7 +292,10 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self): processor = AutoProcessor.from_pretrained(model_id) image_url = "https://llava-vl.github.io/static/images/view.jpg" conversation = [ - {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, { "role": "user", "content": [ @@ -337,3 +344,32 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self): self.assertTrue( check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5) ) + + @slow + @pytest.mark.run_slow + @pytest.mark.skipif(is_linux_ci, reason="OOM") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required") + def test_gemma3_export_to_executorch_in_cuda_recipe(self): + model_id = "google/gemma-3-4b-it" + task = "multimodal-text-to-text" + recipe = "cuda" + output_subdir = "executorch" + + with tempfile.TemporaryDirectory() as tempdir: + out_dir: str = f"{tempdir}/executorch" + subprocess.run( + f"optimum-cli export executorch \ + --model {model_id} \ + --task {task} \ + --recipe {recipe} \ + --output_dir {tempdir}/{output_subdir} \ + --dtype bfloat16 \ + --device cuda \ + --max_seq_len 64", + shell=True, + check=True, + ) + pte_full_path: str = f"{out_dir}/model.pte" + ptd_full_path: str = f"{out_dir}/aoti_cuda_blob.ptd" + self.assertTrue(os.path.exists(pte_full_path)) + self.assertTrue(os.path.exists(ptd_full_path))