@@ -21,27 +21,37 @@ class TestExportLlm(unittest.TestCase):
2121 def test_parse_config_arg_with_config (self ) -> None :
2222 """Test parse_config_arg when --config is provided."""
2323 # Mock sys.argv to include --config
24- test_argv = ["script .py" , "--config" , "test_config.yaml" , "extra" , "args" ]
24+ test_argv = ["export_llm .py" , "--config" , "test_config.yaml" , "extra" , "args" ]
2525 with patch .object (sys , "argv" , test_argv ):
2626 config_path , remaining = parse_config_arg ()
2727 self .assertEqual (config_path , "test_config.yaml" )
2828 self .assertEqual (remaining , ["extra" , "args" ])
2929
3030 def test_parse_config_arg_without_config (self ) -> None :
3131 """Test parse_config_arg when --config is not provided."""
32- test_argv = ["script .py" , "debug.verbose=True" ]
32+ test_argv = ["export_llm .py" , "debug.verbose=True" ]
3333 with patch .object (sys , "argv" , test_argv ):
3434 config_path , remaining = parse_config_arg ()
3535 self .assertIsNone (config_path )
3636 self .assertEqual (remaining , ["debug.verbose=True" ])
3737
3838 def test_pop_config_arg (self ) -> None :
3939 """Test pop_config_arg removes --config and its value from sys.argv."""
40- test_argv = ["script .py" , "--config" , "test_config.yaml" , "other" , "args" ]
40+ test_argv = ["export_llm .py" , "--config" , "test_config.yaml" , "other" , "args" ]
4141 with patch .object (sys , "argv" , test_argv ):
4242 config_path = pop_config_arg ()
4343 self .assertEqual (config_path , "test_config.yaml" )
44- self .assertEqual (sys .argv , ["script.py" , "other" , "args" ])
44+ self .assertEqual (sys .argv , ["export_llm.py" , "other" , "args" ])
45+
46+ def test_with_cli_args (self ) -> None :
47+ """Test main function with only hydra CLI args."""
48+ test_argv = ["export_llm.py" , "debug.verbose=True" ]
49+ with patch .object (sys , "argv" , test_argv ):
50+ with patch (
51+ "executorch.extension.llm.export.export_llm.hydra_main"
52+ ) as mock_hydra :
53+ main ()
54+ mock_hydra .assert_called_once ()
4555
4656 @patch ("executorch.extension.llm.export.export_llm.export_llama" )
4757 def test_with_config (self , mock_export_llama : MagicMock ) -> None :
@@ -70,7 +80,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
7080 config_file = f .name
7181
7282 try :
73- test_argv = ["script .py" , "--config" , config_file ]
83+ test_argv = ["export_llm .py" , "--config" , config_file ]
7484 with patch .object (sys , "argv" , test_argv ):
7585 main ()
7686
@@ -99,54 +109,48 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
99109 finally :
100110 os .unlink (config_file )
101111
102- def test_with_cli_args (self ) -> None :
103- """Test main function with only hydra CLI args."""
104- test_argv = ["script.py" , "debug.verbose=True" ]
105- with patch .object (sys , "argv" , test_argv ):
106- with patch (
107- "executorch.extension.llm.export.export_llm.hydra_main"
108- ) as mock_hydra :
109- main ()
110- mock_hydra .assert_called_once ()
111-
112- def test_config_with_cli_args_error (self ) -> None :
113- """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
112+ @patch ("executorch.extension.llm.export.export_llm.export_llama" )
113+ def test_with_config (self , mock_export_llama : MagicMock ) -> None :
114+ """Test main function with --config file and no hydra args."""
114115 # Create a temporary config file
115116 with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
116- f .write ("base:\n checkpoint: /path/to/checkpoint.pth" )
117- config_file = f .name
118-
119- try :
120- test_argv = ["script.py" , "--config" , config_file , "debug.verbose=True" ]
121- with patch .object (sys , "argv" , test_argv ):
122- with self .assertRaises (ValueError ) as cm :
123- main ()
124-
125- error_msg = str (cm .exception )
126- self .assertIn (
127- "Cannot specify additional CLI arguments when using --config" ,
128- error_msg ,
129- )
130- finally :
131- os .unlink (config_file )
132-
133- def test_config_rejects_multiple_cli_args (self ) -> None :
134- """Test that --config rejects multiple CLI arguments (not just single ones)."""
135- with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".yaml" , delete = False ) as f :
136- f .write ("export:\n max_seq_length: 128" )
117+ f .write (
118+ """
119+ base:
120+ model_class: llama2
121+ model:
122+ dtype_override: fp16
123+ backend:
124+ xnnpack:
125+ enabled: False
126+ """
127+ )
137128 config_file = f .name
138129
139130 try :
140131 test_argv = [
141- "script .py" ,
132+ "export_llm .py" ,
142133 "--config" ,
143134 config_file ,
144- "debug.verbose=True " ,
145- "export.output_dir=/tmp " ,
135+ "base.model_class=stories110m " ,
136+ "backend.xnnpack.enabled=True " ,
146137 ]
147138 with patch .object (sys , "argv" , test_argv ):
148- with self .assertRaises (ValueError ):
149- main ()
139+ main ()
140+
141+ # Verify export_llama was called with config
142+ mock_export_llama .assert_called_once ()
143+ called_config = mock_export_llama .call_args [0 ][0 ]
144+ self .assertEqual (
145+ called_config ["base" ]["model_class" ], "stories110m"
146+ ) # Override from CLI.
147+ self .assertEqual (
148+ called_config ["model" ]["dtype_override" ].value , "fp16"
149+ ) # From yaml.
150+ self .assertEqual (
151+ called_config ["backend" ]["xnnpack" ]["enabled" ],
152+ True , # Override from CLI.
153+ )
150154 finally :
151155 os .unlink (config_file )
152156
0 commit comments