|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | +import tempfile |
| 10 | +import unittest |
| 11 | +from unittest.mock import MagicMock, patch |
| 12 | + |
| 13 | +from executorch.examples.models.llama.config.llm_config import LlmConfig |
| 14 | +from executorch.extension.llm.export.export_llm import main, parse_config_arg, pop_config_arg |
| 15 | + |
| 16 | + |
| 17 | +class TestExportLlm(unittest.TestCase): |
| 18 | + def test_parse_config_arg_with_config(self) -> None: |
| 19 | + """Test parse_config_arg when --config is provided.""" |
| 20 | + # Mock sys.argv to include --config |
| 21 | + test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] |
| 22 | + with patch.object(sys, "argv", test_argv): |
| 23 | + config_path, remaining = parse_config_arg() |
| 24 | + self.assertEqual(config_path, "test_config.yaml") |
| 25 | + self.assertEqual(remaining, ["extra", "args"]) |
| 26 | + |
| 27 | + def test_parse_config_arg_without_config(self) -> None: |
| 28 | + """Test parse_config_arg when --config is not provided.""" |
| 29 | + test_argv = ["script.py", "debug.verbose=True"] |
| 30 | + with patch.object(sys, "argv", test_argv): |
| 31 | + config_path, remaining = parse_config_arg() |
| 32 | + self.assertIsNone(config_path) |
| 33 | + self.assertEqual(remaining, ["debug.verbose=True"]) |
| 34 | + |
| 35 | + def test_pop_config_arg(self) -> None: |
| 36 | + """Test pop_config_arg removes --config and its value from sys.argv.""" |
| 37 | + test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] |
| 38 | + with patch.object(sys, "argv", test_argv): |
| 39 | + config_path = pop_config_arg() |
| 40 | + self.assertEqual(config_path, "test_config.yaml") |
| 41 | + self.assertEqual(sys.argv, ["script.py", "other", "args"]) |
| 42 | + |
| 43 | + @patch("executorch.extension.llm.export.export_llm.export_llama") |
| 44 | + def test_with_config(self, mock_export_llama: MagicMock) -> None: |
| 45 | + """Test main function with --config file and no hydra args.""" |
| 46 | + # Create a temporary config file |
| 47 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 48 | + f.write(""" |
| 49 | +base: |
| 50 | + tokenizer_path: /path/to/tokenizer.json |
| 51 | +export: |
| 52 | + max_seq_length: 256 |
| 53 | +""") |
| 54 | + config_file = f.name |
| 55 | + |
| 56 | + try: |
| 57 | + test_argv = ["script.py", "--config", config_file] |
| 58 | + with patch.object(sys, "argv", test_argv): |
| 59 | + main() |
| 60 | + |
| 61 | + # Verify export_llama was called with config |
| 62 | + mock_export_llama.assert_called_once() |
| 63 | + called_config = mock_export_llama.call_args[0][0] |
| 64 | + self.assertEqual(called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json") |
| 65 | + self.assertEqual(called_config["export"]["max_seq_length"], 256) |
| 66 | + finally: |
| 67 | + os.unlink(config_file) |
| 68 | + |
| 69 | + def test_with_cli_args(self) -> None: |
| 70 | + """Test main function with only hydra CLI args.""" |
| 71 | + test_argv = ["script.py", "debug.verbose=True"] |
| 72 | + with patch.object(sys, "argv", test_argv): |
| 73 | + with patch("executorch.extension.llm.export.export_llm.hydra_main") as mock_hydra: |
| 74 | + main() |
| 75 | + mock_hydra.assert_called_once() |
| 76 | + |
| 77 | + def test_config_with_cli_args_error(self) -> None: |
| 78 | + """Test that --config rejects additional CLI arguments to prevent mixing approaches.""" |
| 79 | + # Create a temporary config file |
| 80 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 81 | + f.write("base:\n checkpoint: /path/to/checkpoint.pth") |
| 82 | + config_file = f.name |
| 83 | + |
| 84 | + try: |
| 85 | + test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] |
| 86 | + with patch.object(sys, "argv", test_argv): |
| 87 | + with self.assertRaises(ValueError) as cm: |
| 88 | + main() |
| 89 | + |
| 90 | + error_msg = str(cm.exception) |
| 91 | + self.assertIn("Cannot specify additional CLI arguments when using --config", error_msg) |
| 92 | + finally: |
| 93 | + os.unlink(config_file) |
| 94 | + |
| 95 | + def test_config_rejects_multiple_cli_args(self) -> None: |
| 96 | + """Test that --config rejects multiple CLI arguments (not just single ones).""" |
| 97 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
| 98 | + f.write("export:\n max_seq_length: 128") |
| 99 | + config_file = f.name |
| 100 | + |
| 101 | + try: |
| 102 | + test_argv = ["script.py", "--config", config_file, "debug.verbose=True", "export.output_dir=/tmp"] |
| 103 | + with patch.object(sys, "argv", test_argv): |
| 104 | + with self.assertRaises(ValueError): |
| 105 | + main() |
| 106 | + finally: |
| 107 | + os.unlink(config_file) |
| 108 | + |
| 109 | + |
| 110 | +if __name__ == "__main__": |
| 111 | + unittest.main() |
| 112 | + |
0 commit comments