@@ -88,23 +88,23 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
8888 mock_export_llama .assert_called_once ()
8989 called_config = mock_export_llama .call_args [0 ][0 ]
9090 self .assertEqual (
91- called_config [ " base" ][ " tokenizer_path" ] , "/path/to/tokenizer.json"
91+ called_config . base . tokenizer_path , "/path/to/tokenizer.json"
9292 )
93- self .assertEqual (called_config [ " base" ][ " model_class" ] , "llama2" )
94- self .assertEqual (called_config [ " base" ][ " preq_mode" ] .value , "8da4w" )
95- self .assertEqual (called_config [ " model" ][ " dtype_override" ] .value , "fp16" )
96- self .assertEqual (called_config [ " export" ][ " max_seq_length" ] , 256 )
93+ self .assertEqual (called_config . base . model_class , "llama2" )
94+ self .assertEqual (called_config . base . preq_mode .value , "8da4w" )
95+ self .assertEqual (called_config . model . dtype_override .value , "fp16" )
96+ self .assertEqual (called_config . export . max_seq_length , 256 )
9797 self .assertEqual (
98- called_config [ " quantization" ][ " pt2e_quantize" ] .value , "xnnpack_dynamic"
98+ called_config . quantization . pt2e_quantize .value , "xnnpack_dynamic"
9999 )
100100 self .assertEqual (
101- called_config [ " quantization" ][ " use_spin_quant" ] .value , "cuda"
101+ called_config . quantization . use_spin_quant .value , "cuda"
102102 )
103103 self .assertEqual (
104- called_config [ " backend" ][ " coreml" ][ " quantize" ] .value , "c4w"
104+ called_config . backend . coreml . quantize .value , "c4w"
105105 )
106106 self .assertEqual (
107- called_config [ " backend" ][ " coreml" ][ " compute_units" ] .value , "cpu_and_gpu"
107+ called_config . backend . coreml . compute_units .value , "cpu_and_gpu"
108108 )
109109 finally :
110110 os .unlink (config_file )
0 commit comments