1
- import sys
2
-
3
- try :
4
- import neural_compressor as inc
5
- print ("neural_compressor version {}" .format (inc .__version__ ))
6
- except :
7
- try :
8
- import lpot as inc
9
- print ("LPOT version {}" .format (inc .__version__ ))
10
- except :
11
- import ilit as inc
12
- print ("iLiT version {}" .format (inc .__version__ ))
13
-
14
- if inc .__version__ == '1.2' :
15
- print ("This script doesn't support LPOT 1.2, please install LPOT 1.1, 1.2.1 or newer" )
16
- sys .exit (1 )
1
+ import neural_compressor as inc
2
+ print ("neural_compressor version {}" .format (inc .__version__ ))
17
3
18
4
import alexnet
19
5
import math
6
+ import yaml
20
7
import mnist_dataset
8
+ from neural_compressor .quantization import fit
9
+ from neural_compressor .config import PostTrainingQuantConfig , TuningCriterion , AccuracyCriterion
21
10
22
11
23
12
def save_int8_frezon_pb(q_model, path):
    """Serialize the quantized model's graph to a frozen .pb file.

    Args:
        q_model: Quantized model returned by neural_compressor's fit();
            exposes a TensorFlow graph via ``q_model.graph``.
        path: Destination file path for the serialized GraphDef.
    """
    from tensorflow.python.platform import gfile
    # Context manager guarantees the file is flushed and closed even if
    # serialization raises; the original left the GFile handle open.
    with gfile.GFile(path, 'wb') as f:
        f.write(q_model.graph.as_graph_def().SerializeToString())
    print("Save to {}".format(path))
28
17
29
18
@@ -44,23 +33,31 @@ def __iter__(self):
44
33
yield x_test [begin :], label_test [begin :]
45
34
46
35
47
def auto_tune(input_graph_path, config, batch_size):
    """Run Intel Neural Compressor post-training quantization with auto-tuning.

    Args:
        input_graph_path: Path to the FP32 frozen-graph ``.pb`` file; passed
            directly to ``fit()``, which accepts a graph path as the model.
        config: Dict with ``"tuning_criterion"``, ``"accuracy_criterion"`` and
            ``"quant_config"`` sections (loaded from quant_config.yaml).
        batch_size: Calibration batch size for the Dataloader.

    Returns:
        The quantized model object produced by ``fit()``.

    Raises:
        ValueError: If the calibration dataloader could not be created.
    """
    # NOTE: the original also called alexnet.load_pb(input_graph_path) here and
    # discarded the result — dead code removed, fit() loads from the path itself.
    dataloader = Dataloader(batch_size)
    if not dataloader:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError("Failed to create calibration dataloader")

    tuning_criterion = TuningCriterion(**config["tuning_criterion"])
    accuracy_criterion = AccuracyCriterion(**config["accuracy_criterion"])
    q_model = fit(
        model=input_graph_path,
        conf=PostTrainingQuantConfig(
            **config["quant_config"],
            tuning_criterion=tuning_criterion,
            accuracy_criterion=accuracy_criterion,
        ),
        calib_dataloader=dataloader,
    )
    return q_model
58
52
59
53
60
- yaml_file = "alexnet.yaml"
61
54
batch_size = 200
62
55
fp32_frezon_pb_file = "fp32_frezon.pb"
63
56
int8_pb_file = "alexnet_int8_model.pb"
64
57
65
- q_model = auto_tune (fp32_frezon_pb_file , yaml_file , batch_size )
58
+ with open ("quant_config.yaml" ) as f :
59
+ config = yaml .safe_load (f .read ())
60
+ config
61
+
62
+ q_model = auto_tune (fp32_frezon_pb_file , config , batch_size )
66
63
save_int8_frezon_pb (q_model , int8_pb_file )
0 commit comments