@@ -88,23 +88,19 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
8888 mock_export_llama .assert_called_once ()
8989 called_config = mock_export_llama .call_args [0 ][0 ]
9090 self .assertEqual (
91- called_config [ " base" ][ " tokenizer_path" ] , "/path/to/tokenizer.json"
91+ called_config . base . tokenizer_path , "/path/to/tokenizer.json"
9292 )
93- self .assertEqual (called_config [ " base" ][ " model_class" ] , "llama2" )
94- self .assertEqual (called_config [ " base" ][ " preq_mode" ] .value , "8da4w" )
95- self .assertEqual (called_config [ " model" ][ " dtype_override" ] .value , "fp16" )
96- self .assertEqual (called_config [ " export" ][ " max_seq_length" ] , 256 )
93+ self .assertEqual (called_config . base . model_class , "llama2" )
94+ self .assertEqual (called_config . base . preq_mode .value , "8da4w" )
95+ self .assertEqual (called_config . model . dtype_override .value , "fp16" )
96+ self .assertEqual (called_config . export . max_seq_length , 256 )
9797 self .assertEqual (
98- called_config [ " quantization" ][ " pt2e_quantize" ] .value , "xnnpack_dynamic"
98+ called_config . quantization . pt2e_quantize .value , "xnnpack_dynamic"
9999 )
100+ self .assertEqual (called_config .quantization .use_spin_quant .value , "cuda" )
101+ self .assertEqual (called_config .backend .coreml .quantize .value , "c4w" )
100102 self .assertEqual (
101- called_config ["quantization" ]["use_spin_quant" ].value , "cuda"
102- )
103- self .assertEqual (
104- called_config ["backend" ]["coreml" ]["quantize" ].value , "c4w"
105- )
106- self .assertEqual (
107- called_config ["backend" ]["coreml" ]["compute_units" ].value , "cpu_and_gpu"
103+ called_config .backend .coreml .compute_units .value , "cpu_and_gpu"
108104 )
109105 finally :
110106 os .unlink (config_file )
@@ -142,13 +138,13 @@ def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
142138 mock_export_llama .assert_called_once ()
143139 called_config = mock_export_llama .call_args [0 ][0 ]
144140 self .assertEqual (
145- called_config [ " base" ][ " model_class" ] , "stories110m"
141+ called_config . base . model_class , "stories110m"
146142 ) # Override from CLI.
147143 self .assertEqual (
148- called_config [ " model" ][ " dtype_override" ] .value , "fp16"
144+ called_config . model . dtype_override .value , "fp16"
149145 ) # From yaml.
150146 self .assertEqual (
151- called_config [ " backend" ][ " xnnpack" ][ " enabled" ] ,
147+ called_config . backend . xnnpack . enabled ,
152148 True , # Override from CLI.
153149 )
154150 finally :
0 commit comments