|
10 | 10 | import numpy as np
|
11 | 11 | import tensorflow as tf
|
12 | 12 |
|
| 13 | +from tensorflow import keras |
| 14 | +from tensorflow.keras import layers |
13 | 15 | from tensorflow.python.ops import init_ops, random_ops, init_ops
|
14 | 16 | from tensorflow.python.ops.array_ops import FakeQuantWithMinMaxVars
|
| 17 | +from tensorflow.python.framework.graph_util import convert_variables_to_constants |
| 18 | +from tensorflow_model_optimization.quantization.keras import quantize_model |
15 | 19 | from backend_test_base import Tf2OnnxBackendTestBase
|
16 | 20 | from common import unittest_main, check_gru_count, check_opset_after_tf_version, skip_tf2
|
17 | 21 | from tf2onnx.tf_loader import is_tf2
|
|
28 | 32 | dynamic_rnn = tf.nn.dynamic_rnn
|
29 | 33 |
|
30 | 34 |
|
31 |
| -def quantize_model_save(keras_file, tflite_file): |
32 |
| - with quantize.quantize_scope(): |
33 |
| - model = tf.keras.models.load_model(keras_file) |
34 |
| - converter = tf.lite.TFLiteConverter.from_keras_model(model) |
| 35 | +from keras import backend as K |
| 36 | +import tensorflow as tf |
| 37 | + |
| 38 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation |
| 39 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper |
| 40 | +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry |
| 41 | +from keras2onnx import convert_keras |
| 42 | + |
| 43 | +QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation |
| 44 | +QuantizeWrapper = quantize_wrapper.QuantizeWrapper |
| 45 | +QuantizeRegistry = default_8bit_quantize_registry.QuantizeRegistry |
| 46 | + |
| 47 | +keras = tf.keras |
| 48 | +layers = tf.keras.layers |
35 | 49 |
|
36 |
| - converter.representative_dataset = calibration_gen |
37 |
| - converter._experimental_new_quantizer = True # pylint: disable=protected-access |
38 |
| - converter.target_spec.supported_ops = [ |
39 |
| - tf.lite.OpsSet.TFLITE_BUILTINS_INT8 |
40 |
| - ] # to enable post-training quantization with the representative dataset |
| 50 | +custom_object_scope = tf.keras.utils.custom_object_scope |
| 51 | +deserialize_layer = tf.keras.layers.deserialize |
| 52 | +serialize_layer = tf.keras.layers.serialize |
41 | 53 |
|
42 |
| - tflite_model = converter.convert() |
43 |
| - tflite_file = 'quantized_mnist.tflite' |
44 |
| - open(tflite_file, 'wb').write(tflite_model) |
| 54 | + |
| 55 | + |
def freeze_session(graph, keep_var_names=None, output_names=None, clear_devices=True, session=None):
    """
    Freezes the variables of a graph into a pruned computation graph.

    Creates a new graph definition where variable nodes are replaced by
    constants taking their current value in the session. The new graph is
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param graph The TensorFlow graph to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The list is not
                        modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @param session Session the variable values are read from. When None, the
                   currently active default session is used.
    @return The frozen graph definition.
    @raise ValueError if no session is given and no default session is active.

    Source: https://www.dlology.com/blog/how-to-convert-trained-keras-model-to-tensorflow-and-make-prediction/
    """
    tf_v1 = tf.compat.v1
    if session is None:
        # The variable values live in a session, not in the graph itself.
        session = tf_v1.get_default_session()
    if session is None:
        raise ValueError(
            "freeze_session needs a session to read variable values from: "
            "pass session=... or call it inside 'with session.as_default():'")

    with graph.as_default():
        variable_names = [v.op.name for v in tf_v1.global_variables()]
        freeze_var_names = list(
            set(variable_names).difference(keep_var_names or []))
        # Build a new list so the caller's output_names is left untouched.
        all_output_names = list(output_names or []) + variable_names
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        return convert_variables_to_constants(
            session, input_graph_def, all_output_names, freeze_var_names)
45 | 87 |
|
46 | 88 |
|
class QuantizationTests(Tf2OnnxBackendTestBase):
    """Conversion tests for Keras models quantized with tensorflow_model_optimization."""

    def setUp(self):
        """Create the registry used to look up per-layer quantize configs."""
        super(QuantizationTests, self).setUp()
        self.quantize_registry = QuantizeRegistry()

    def test_quantize_keras(self):
        """Quantize a small Sequential model, run it, and convert it to ONNX."""
        model = quantize_model(
            keras.Sequential([
                layers.Dense(3, activation='relu', input_shape=(5,)),
                layers.Dense(3, activation='relu', input_shape=(3,)),
            ]))
        model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
        # summary() prints the table itself and returns None.
        model.summary()
        x = np.array([[0, 1, 2, 3, 4]], dtype=np.float32)
        y = model.predict(x)
        print(y)
        model_onnx = convert_keras(model)
        print(model_onnx)
        assert model_onnx is not None

    def test_quantize_tf(self):
        """Quantize a functional model, freeze its graph, and convert it to ONNX."""
        inputs = tf.keras.layers.Input(shape=(5,))
        inter = tf.keras.layers.Dense(3, activation='relu')(inputs)
        outputs = tf.keras.layers.Dense(3, activation='relu')(inter)
        model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
        model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
        model = quantize_model(model)
        model.summary()
        x = np.array([[0, 1, 2, 3, 4]], dtype=np.float32)
        y = model(x)
        print(y)
        print(type(model))

        # freeze_session takes a graph, not a Keras model. Use the graph of the
        # Keras backend session and make that session the default one so the
        # variable values can be read from it.
        # NOTE(review): assumes TF1-style graph mode, where model outputs are
        # graph tensors exposing `.op` -- confirm for TF2/eager runs.
        session = tf.compat.v1.keras.backend.get_session()
        output_names = [out.op.name for out in model.outputs]
        with session.as_default():
            frozen_graph = freeze_session(session.graph, output_names=output_names)
        assert frozen_graph is not None

        model_onnx = convert_keras(model)
        assert model_onnx is not None

    def testQuantizesOutputsFromLayer(self):
        """Check a QuantizeWrapper'd ReLU against fake-quant, then against its ONNX conversion."""
        # source https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py
        # TODO(pulkitb): Increase coverage by adding other output quantize layers
        # such as AveragePooling etc.

        layer = layers.ReLU()
        quantized_model = keras.Sequential([
            QuantizeWrapper(
                layers.ReLU(),
                quantize_config=self.quantize_registry.get_quantize_config(layer))
        ])

        model = keras.Sequential([layers.ReLU()])

        # float32 so the same array can be fed to ONNX Runtime below.
        # NOTE(review): assumes the converted model declares a float32 input
        # (the Keras default dtype) -- confirm.
        inputs = np.random.rand(1, 2, 1).astype(np.float32)
        expected_output = tf.quantization.fake_quant_with_min_max_vars(
            model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False)
        exp = quantized_model.predict(inputs)
        self.assertAllClose(expected_output, exp)

        model_onnx = convert_keras(model)
        quantized_model_onnx = convert_keras(quantized_model)
        assert model_onnx is not None
        assert quantized_model_onnx is not None
        if 'op_type: "FakeQuantWithMinMaxVars"' in str(quantized_model_onnx):
            raise AssertionError(
                "FakeQuantWithMinMaxVars not replaced\n{}".format(quantized_model_onnx))
        assert 'op_type: "QuantizeLinear"' in str(quantized_model_onnx)
        assert 'op_type: "DequantizeLinear"' in str(quantized_model_onnx)
        from onnxruntime import InferenceSession
        sess = InferenceSession(quantized_model_onnx.SerializeToString())
        names = [_.name for _ in sess.get_inputs()]
        # run() returns one array per model output; the model has a single output.
        got = sess.run(None, {names[0]: inputs})
        self.assertAllClose(expected_output, got[0])
| 167 | + |
| 168 | + |
if __name__ == "__main__":
    # Hand control to the shared test runner when executed as a script.
    unittest_main()
|
0 commit comments