diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py
index 0b811996..c37146f3 100644
--- a/optimum/exporters/executorch/recipes/xnnpack.py
+++ b/optimum/exporters/executorch/recipes/xnnpack.py
@@ -36,7 +36,7 @@
     Seq2SeqLMExportableModule,
 )
 from ..recipe_registry import register_recipe
-
+from torchao.utils import unwrap_tensor_subclass

 @register_recipe("xnnpack")
 def export_to_executorch_with_xnnpack(
@@ -104,6 +104,8 @@ def _lower_to_executorch(
         )
         return {pte_name: et_prog}

+    model = unwrap_tensor_subclass(model)
+    print("Unwrapped model: ", model)
     exported_progs = model.export()

     if (
diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py
index ff507bb6..e6a08449 100644
--- a/tests/models/test_modeling_gemma3.py
+++ b/tests/models/test_modeling_gemma3.py
@@ -279,11 +279,12 @@ def test_gemma3_270m_text_generation_with_custom_sdpa_8da4w_8we(self):

         self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

-    @slow
-    @pytest.mark.run_slow
-    @pytest.mark.skipif(is_linux_ci, reason="OOM")
+    # @slow
+    # @pytest.mark.run_slow
+    # @pytest.mark.skipif(is_linux_ci, reason="OOM")
     def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
-        model_id = "google/gemma-3-4b-it"
+        # model_id = "google/gemma-3-4b-it"
+        model_id = "metascroy/gemma-3-4b-it-INT8-INT4"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         processor = AutoProcessor.from_pretrained(model_id)
         image_url = "https://llava-vl.github.io/static/images/view.jpg"
@@ -307,11 +308,12 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
             task="multimodal-text-to-text",
             use_custom_sdpa=True,
             use_custom_kv_cache=True,
-            qlinear="8da4w",
-            qlinear_group_size=32,
+            device_map="cpu",
+            # qlinear="8da4w",
+            # qlinear_group_size=32,
             # Can't quantize the encoder a the moment, hidden dim of 4304 doesn't fit ExecuTorch's
             # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
-            qembedding_config="8w",
+            # qembedding_config="8w",
         )

         # Generate