@@ -65,20 +65,17 @@ def load_calibration_dataset(dataset_path: str):
    return calibration_dataset


-def quantize_model(model: torch.fx.GraphModule, calibration_dataset: torch.utils.data.DataLoader, subset_size=300):
-    quantizer = OpenVINOQuantizer()
+def quantize_model(model: torch.fx.GraphModule, example_args, subset_size=300):
+    quantizer = OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(types=["__getitem__", "layer_norm"]))

    print("PTQ: Annotate the model...")
    annotated_model = prepare_pt2e(model, quantizer)

    print("PTQ: Calibrate the model...")
-    for idx, data in enumerate(calibration_dataset):
-        if idx >= subset_size:
-            break
-        annotated_model(data[0])
+    annotated_model(*example_args)

    print("PTQ: Convert the quantized model...")
-    quantized_model = convert_pt2e(annotated_model)
+    quantized_model = convert_pt2e(annotated_model, fold_quantize=False)

    return quantized_model

@@ -106,7 +103,9 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
        calibration_dataset = load_calibration_dataset(dataset_path)

        captured_model = aten_dialect.module()
-        quantized_model = quantize_model(captured_model, calibration_dataset)
+        visualize_fx_model(captured_model, f"{model_name}_fp32.svg")
+        quantized_model = quantize_model(captured_model, example_args)
+        visualize_fx_model(quantized_model, f"{model_name}_int8.svg")
        aten_dialect: ExportedProgram = export(quantized_model, example_args)

    # Convert to edge dialect
@@ -121,9 +120,15 @@ def main(suite: str, model_name: str, input_shape, quantize: bool, dataset_path:
    exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig())

    # Serialize and save it to a file
-    with open(f"{model_name}.pte", "wb") as file:
+    model_name = f"{model_name}_{'int8' if quantize else 'fp32'}.pte"
+    with open(model_name, "wb") as file:
        exec_prog.write_to_file(file)
-    print(f"Model exported and saved as {model_name}.pte on {device}.")
+    print(f"Model exported and saved as {model_name} on {device}.")
+
+from torch.fx.passes.graph_drawer import FxGraphDrawer
+def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
+    g = FxGraphDrawer(model, output_svg_path)
+    g.get_dot_graph().write_svg(output_svg_path)


if __name__ == "__main__":
    # Argument parser for dynamic inputs
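
For reference, a minimal standalone sketch of how the visualize_fx_model helper added in this patch can be exercised on its own. TinyModel and the output file names are hypothetical; the sketch assumes torch.export is available and that pydot and graphviz are installed for FxGraphDrawer.

# Hypothetical standalone sketch: render an exported FX graph as an SVG file.
# Assumes pydot and graphviz are installed; TinyModel stands in for the real model.
import torch
from torch.export import export
from torch.fx.passes.graph_drawer import FxGraphDrawer


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)


example_args = (torch.randn(1, 3),)
# torch.export captures the model; .module() yields the fx.GraphModule to draw.
captured_model = export(TinyModel(), example_args).module()

# Same pattern as visualize_fx_model above: build the pydot graph and save it.
g = FxGraphDrawer(captured_model, "tiny_model")
g.get_dot_graph().write_svg("tiny_model_fp32.svg")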