
Commit 0032869

jackzhxng and hinriksnaer authored and committed

Fix test_llm_config (pytorch#11977)

Differential Revision: D77321057

1 parent 6fed744

File tree: 3 files changed (+7, -6 lines)


examples/models/llama/config/test_llm_config.py

Lines changed: 4 additions & 4 deletions
@@ -41,7 +41,7 @@ def test_local_global_attention_without_kv(self):
 
     def test_invalid_export_config_context_length(self):
         with self.assertRaises(ValueError):
-            ExportConfig(max_seq_length=128, max_context_length=256)
+            ExportConfig(max_seq_length=256, max_context_length=128)
 
     def test_invalid_qmode(self):
         with self.assertRaises(ValueError):
@@ -84,8 +84,8 @@ def test_valid_llm_config(self):
                 local_global_attention="[16, 32]",
             ),
             export=ExportConfig(
-                max_seq_length=256,
-                max_context_length=128,
+                max_seq_length=128,
+                max_context_length=256,
                 output_dir="/tmp/export",
                 output_name="model.pte",
             ),
@@ -94,7 +94,7 @@ def test_valid_llm_config(self):
             backend=BackendConfig(
                 xnnpack=XNNPackConfig(enabled=False),
                 coreml=CoreMLConfig(
-                    enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL
+                    enabled=True, ios=17, compute_units=CoreMLComputeUnit.cpu_only
                 ),
             ),
         )
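
The corrected test now expects a ValueError when max_seq_length exceeds max_context_length, i.e. the prompt length may not be larger than the KV-cache context window. A minimal sketch of that kind of check, assuming a dataclass-style config (the field names come from the test above; the __post_init__ validation shown here is an illustration, not the actual ExecuTorch implementation):

from dataclasses import dataclass


@dataclass
class ExportConfig:
    # Hypothetical sketch; the real ExecuTorch ExportConfig has more fields.
    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: str = "."
    output_name: str = "model.pte"

    def __post_init__(self) -> None:
        # The fixed test passes (256, 128) and expects a ValueError,
        # so the sequence length must not exceed the context length.
        if self.max_seq_length > self.max_context_length:
            raise ValueError(
                f"max_seq_length ({self.max_seq_length}) must not exceed "
                f"max_context_length ({self.max_context_length})"
            )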

extension/llm/export/test/test_export_llm.py

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
 model:
   dtype_override: fp16
 export:
-  max_seq_length: 256
+  max_seq_length: 128
 quantization:
   pt2e_quantize: xnnpack_dynamic
   use_spin_quant: cuda
@@ -93,7 +93,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         self.assertEqual(called_config.base.model_class, "llama2")
         self.assertEqual(called_config.base.preq_mode.value, "8da4w")
         self.assertEqual(called_config.model.dtype_override.value, "fp16")
-        self.assertEqual(called_config.export.max_seq_length, 256)
+        self.assertEqual(called_config.export.max_seq_length, 128)
         self.assertEqual(
             called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
         )
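
As a quick sanity check outside the test harness, the same YAML fragment can be loaded with PyYAML to confirm the value the updated assertion expects (this snippet is illustrative, assumes PyYAML is installed, and is not part of the ExecuTorch test suite):

import yaml

config_text = """
model:
  dtype_override: fp16
export:
  max_seq_length: 128
quantization:
  pt2e_quantize: xnnpack_dynamic
"""

config = yaml.safe_load(config_text)
# Mirrors the updated assertion in test_with_config.
assert config["export"]["max_seq_length"] == 128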

pytest.ini

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ addopts =
     --ignore=devtools/visualization/visualization_utils_test.py
     # examples
     examples/models/llama/tests
+    examples/models/llama/config
     examples/models/llama3_2_vision/preprocess
     examples/models/llama3_2_vision/vision_encoder/test
     examples/models/llama3_2_vision/text_decoder/test
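
With this entry added to addopts, the config tests are collected in the default pytest run; they can also be invoked directly with, for example, pytest examples/models/llama/config.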
