|
| 1 | +import sys |
| 2 | + |
1 | 3 | import neural_compressor as inc
|
2 | 4 | print("neural_compressor version {}".format(inc.__version__))
|
3 | 5 |
|
4 |
| -import alexnet |
5 |
| -import math |
6 |
| -import yaml |
7 |
| -import mnist_dataset |
| 6 | +import tensorflow as tf |
| 7 | +print("tensorflow {}".format(tf.__version__)) |
| 8 | + |
| 9 | +from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion, TuningCriterion |
| 10 | +from neural_compressor.data import DataLoader |
8 | 11 | from neural_compressor.quantization import fit
|
9 |
| -from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion, AccuracyCriterion |
| 12 | +from neural_compressor import Metric |
10 | 13 |
|
| 14 | +import mnist_dataset |
11 | 15 |
|
12 |
| -def save_int8_frezon_pb(q_model, path): |
13 |
| - from tensorflow.python.platform import gfile |
14 |
| - f = gfile.GFile(path, 'wb') |
15 |
| - f.write(q_model.graph.as_graph_def().SerializeToString()) |
16 |
| - print("Save to {}".format(path)) |
17 | 16 |
|
| 17 | +class Dataset(object): |
| 18 | + def __init__(self): |
| 19 | + _x_train, _y_train, label_train, x_test, y_test, label_test = mnist_dataset.read_data() |
18 | 20 |
|
19 |
| -class Dataloader(object): |
20 |
| - def __init__(self, batch_size): |
21 |
| - self.batch_size = batch_size |
| 21 | + self.test_images = x_test |
| 22 | + self.labels = label_test |
22 | 23 |
|
23 |
| - def __iter__(self): |
24 |
| - x_train, y_train, label_train, x_test, y_test, label_test = mnist_dataset.read_data() |
25 |
| - batch_nums = math.ceil(len(x_test) / self.batch_size) |
| 24 | + def __getitem__(self, index): |
| 25 | + return self.test_images[index], self.labels[index] |
26 | 26 |
|
27 |
| - for i in range(batch_nums - 1): |
28 |
| - begin = i * self.batch_size |
29 |
| - end = (i + 1) * self.batch_size |
30 |
| - yield x_test[begin: end], label_test[begin: end] |
| 27 | + def __len__(self): |
| 28 | + return len(self.test_images) |
31 | 29 |
|
32 |
| - begin = (batch_nums - 1) * self.batch_size |
33 |
| - yield x_test[begin:], label_test[begin:] |
34 | 30 |
|
| 31 | +def auto_tune(input_graph_path, batch_size): |
| 32 | + dataset = Dataset() |
| 33 | + dataloader = DataLoader(framework='tensorflow', dataset=dataset, batch_size=batch_size) |
| 34 | + tuning_criterion = TuningCriterion(max_trials=100) |
| 35 | + config = PostTrainingQuantConfig(approach="static", tuning_criterion=tuning_criterion, |
| 36 | + accuracy_criterion = AccuracyCriterion( |
| 37 | + higher_is_better=True, |
| 38 | + criterion='relative', |
| 39 | + tolerable_loss=0.01 ) |
| 40 | + ) |
| 41 | + top1 = Metric(name="topk", k=1) |
35 | 42 |
|
36 |
| -def auto_tune(input_graph_path, config, batch_size): |
37 |
| - fp32_graph = alexnet.load_pb(input_graph_path) |
38 |
| - dataloader = Dataloader(batch_size) |
39 |
| - assert(dataloader) |
40 |
| - |
41 |
| - tuning_criterion = TuningCriterion(**config["tuning_criterion"]) |
42 |
| - accuracy_criterion = AccuracyCriterion(**config["accuracy_criterion"]) |
43 | 43 | q_model = fit(
|
44 |
| - model=input_graph_path, |
45 |
| - conf=PostTrainingQuantConfig(**config["quant_config"], |
46 |
| - tuning_criterion=tuning_criterion, |
47 |
| - accuracy_criterion=accuracy_criterion, |
48 |
| - ), |
49 |
| - calib_dataloader=dataloader, |
| 44 | + model=input_graph_path, |
| 45 | + conf=config, |
| 46 | + calib_dataloader=dataloader, |
| 47 | + eval_dataloader=dataloader, |
| 48 | + eval_metric=top1 |
50 | 49 | )
|
| 50 | + |
| 51 | + |
51 | 52 | return q_model
|
52 | 53 |
|
53 | 54 |
|
54 | 55 | batch_size = 200
|
55 |
| -fp32_frezon_pb_file = "fp32_frezon.pb" |
| 56 | +fp32_frozen_pb_file = "fp32_frozen.pb" |
56 | 57 | int8_pb_file = "alexnet_int8_model.pb"
|
57 | 58 |
|
58 |
| -with open("quant_config.yaml") as f: |
59 |
| - config = yaml.safe_load(f.read()) |
60 |
| -config |
61 |
| - |
62 |
| -q_model = auto_tune(fp32_frezon_pb_file, config, batch_size) |
63 |
| -save_int8_frezon_pb(q_model, int8_pb_file) |
| 59 | +q_model = auto_tune(fp32_frozen_pb_file, batch_size) |
| 60 | +q_model.save(int8_pb_file) |
0 commit comments