diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py index 0853e9dbbd8..52b56d71a03 100644 --- a/examples/models/llama/config/test_llm_config.py +++ b/examples/models/llama/config/test_llm_config.py @@ -41,7 +41,7 @@ def test_local_global_attention_without_kv(self): def test_invalid_export_config_context_length(self): with self.assertRaises(ValueError): - ExportConfig(max_seq_length=128, max_context_length=256) + ExportConfig(max_seq_length=256, max_context_length=128) def test_invalid_qmode(self): with self.assertRaises(ValueError): @@ -84,8 +84,8 @@ def test_valid_llm_config(self): local_global_attention="[16, 32]", ), export=ExportConfig( - max_seq_length=256, - max_context_length=128, + max_seq_length=128, + max_context_length=256, output_dir="/tmp/export", output_name="model.pte", ), @@ -94,7 +94,7 @@ def test_valid_llm_config(self): backend=BackendConfig( xnnpack=XNNPackConfig(enabled=False), coreml=CoreMLConfig( - enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL + enabled=True, ios=17, compute_units=CoreMLComputeUnit.cpu_only ), ), ) diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py index e6f7160d4af..ab7db1b4e3a 100644 --- a/extension/llm/export/test/test_export_llm.py +++ b/extension/llm/export/test/test_export_llm.py @@ -67,7 +67,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: model: dtype_override: fp16 export: - max_seq_length: 256 + max_seq_length: 128 quantization: pt2e_quantize: xnnpack_dynamic use_spin_quant: cuda @@ -93,7 +93,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: self.assertEqual(called_config.base.model_class, "llama2") self.assertEqual(called_config.base.preq_mode.value, "8da4w") self.assertEqual(called_config.model.dtype_override.value, "fp16") - self.assertEqual(called_config.export.max_seq_length, 256) + self.assertEqual(called_config.export.max_seq_length, 128) self.assertEqual( called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic" ) diff --git a/pytest.ini b/pytest.ini index 557a307bdf2..e0f8eafb082 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,6 +18,7 @@ addopts = --ignore=devtools/visualization/visualization_utils_test.py # examples examples/models/llama/tests + examples/models/llama/config examples/models/llama3_2_vision/preprocess examples/models/llama3_2_vision/vision_encoder/test examples/models/llama3_2_vision/text_decoder/test