Skip to content

Commit 06164b6

Browse files
committed
Fix test
1 parent 4dfd43f commit 06164b6

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

extension/llm/export/test/test_export_llm.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,19 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
8888
mock_export_llama.assert_called_once()
8989
called_config = mock_export_llama.call_args[0][0]
9090
self.assertEqual(
91-
called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
91+
called_config.base.tokenizer_path, "/path/to/tokenizer.json"
9292
)
93-
self.assertEqual(called_config["base"]["model_class"], "llama2")
94-
self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
95-
self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
96-
self.assertEqual(called_config["export"]["max_seq_length"], 256)
93+
self.assertEqual(called_config.base.model_class, "llama2")
94+
self.assertEqual(called_config.base.preq_mode.value, "8da4w")
95+
self.assertEqual(called_config.model.dtype_override.value, "fp16")
96+
self.assertEqual(called_config.export.max_seq_length, 256)
9797
self.assertEqual(
98-
called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
98+
called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
9999
)
100+
self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda")
101+
self.assertEqual(called_config.backend.coreml.quantize.value, "c4w")
100102
self.assertEqual(
101-
called_config["quantization"]["use_spin_quant"].value, "cuda"
102-
)
103-
self.assertEqual(
104-
called_config["backend"]["coreml"]["quantize"].value, "c4w"
105-
)
106-
self.assertEqual(
107-
called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
103+
called_config.backend.coreml.compute_units.value, "cpu_and_gpu"
108104
)
109105
finally:
110106
os.unlink(config_file)
@@ -142,13 +138,13 @@ def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
142138
mock_export_llama.assert_called_once()
143139
called_config = mock_export_llama.call_args[0][0]
144140
self.assertEqual(
145-
called_config["base"]["model_class"], "stories110m"
141+
called_config.base.model_class, "stories110m"
146142
) # Override from CLI.
147143
self.assertEqual(
148-
called_config["model"]["dtype_override"].value, "fp16"
144+
called_config.model.dtype_override.value, "fp16"
149145
) # From yaml.
150146
self.assertEqual(
151-
called_config["backend"]["xnnpack"]["enabled"],
147+
called_config.backend.xnnpack.enabled,
152148
True, # Override from CLI.
153149
)
154150
finally:

0 commit comments

Comments
 (0)