
Commit eca8647

Replace operator FakeQuantWithMinMaxVars
1 parent 1cb1077 commit eca8647
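
For context: ONNX has no FakeQuantWithMinMaxVars operator, so replacing it means expressing the (min, max) range as the scale and integer zero point that QuantizeLinear/DequantizeLinear use. A minimal NumPy sketch of that mapping, assuming the standard ONNX QuantizeLinear semantics q = saturate(round(x / scale) + zero_point); the helper name minmax_to_scale_zero_point is illustrative, not part of this commit:

    import numpy as np

    def minmax_to_scale_zero_point(min_val, max_val, num_bits=8, narrow_range=False):
        # Illustrative helper: map a TF fake-quant (min, max) range onto the
        # scale and integer zero point used by ONNX QuantizeLinear.
        qmin = 1 if narrow_range else 0
        qmax = 2 ** num_bits - 1
        scale = (max_val - min_val) / float(qmax - qmin)
        # Round the zero point to an integer so 0.0 stays exactly representable,
        # mirroring the range nudging FakeQuantWithMinMaxVars performs internally.
        zero_point = int(np.clip(round(qmin - min_val / scale), qmin, qmax))
        return scale, zero_point

    # The (-6, 6) range asserted in testQuantizesOutputsFromLayer below:
    print(minmax_to_scale_zero_point(-6.0, 6.0))  # (0.047058823..., 128)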

File tree

1 file changed: +128 -19

tests/test_quantization.py

Lines changed: 128 additions & 19 deletions
@@ -10,8 +10,12 @@
 import numpy as np
 import tensorflow as tf
 
+from tensorflow import keras
+from tensorflow.keras import layers
 from tensorflow.python.ops import init_ops, random_ops
 from tensorflow.python.ops.array_ops import FakeQuantWithMinMaxVars
+from tensorflow.python.framework.graph_util import convert_variables_to_constants
+from tensorflow_model_optimization.quantization.keras import quantize_model
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import unittest_main, check_gru_count, check_opset_after_tf_version, skip_tf2
 from tf2onnx.tf_loader import is_tf2
@@ -28,34 +32,139 @@
 dynamic_rnn = tf.nn.dynamic_rnn
 
 
-def quantize_model_save(keras_file, tflite_file):
-    with quantize.quantize_scope():
-        model = tf.keras.models.load_model(keras_file)
-    converter = tf.lite.TFLiteConverter.from_keras_model(model)
+from keras import backend as K
+import tensorflow as tf
+
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
+from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
+from keras2onnx import convert_keras
+
+QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
+QuantizeWrapper = quantize_wrapper.QuantizeWrapper
+QuantizeRegistry = default_8bit_quantize_registry.QuantizeRegistry
+
+keras = tf.keras
+layers = tf.keras.layers
 
-    converter.representative_dataset = calibration_gen
-    converter._experimental_new_quantizer = True  # pylint: disable=protected-access
-    converter.target_spec.supported_ops = [
-        tf.lite.OpsSet.TFLITE_BUILTINS_INT8
-    ]  # to enable post-training quantization with the representative dataset
+custom_object_scope = tf.keras.utils.custom_object_scope
+deserialize_layer = tf.keras.layers.deserialize
+serialize_layer = tf.keras.layers.serialize
 
-    tflite_model = converter.convert()
-    tflite_file = 'quantized_mnist.tflite'
-    open(tflite_file, 'wb').write(tflite_model)
+
+
+def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
+    """
+    Freezes the state of a session into a pruned computation graph.
+
+    Creates a new computation graph where variable nodes are replaced by
+    constants taking their current value in the session. The new graph will be
+    pruned so subgraphs that are not necessary to compute the requested
+    outputs are removed.
+    @param session The TensorFlow session whose graph should be frozen.
+    @param keep_var_names A list of variable names that should not be frozen,
+        or None to freeze all the variables in the graph.
+    @param output_names Names of the relevant graph outputs.
+    @param clear_devices Remove the device directives from the graph for better portability.
+    @return The frozen graph definition.
+
+    Source: https://www.dlology.com/blog/how-to-convert-trained-keras-model-to-tensorflow-and-make-prediction/
+    """
+    graph = session.graph
+    with graph.as_default():
+        freeze_var_names = list(set(
+            v.op.name for v in tf.global_variables()).difference(
+                keep_var_names or []))
+        output_names = output_names or []
+        output_names += [v.op.name for v in tf.global_variables()]
+        # Graph -> GraphDef ProtoBuf
+        input_graph_def = graph.as_graph_def()
+        if clear_devices:
+            for node in input_graph_def.node:
+                node.device = ""
+        frozen_graph = convert_variables_to_constants(
+            session, input_graph_def, output_names, freeze_var_names)
+    return frozen_graph
 
 
 class QuantizationTests(Tf2OnnxBackendTestBase):
+
+    def setUp(self):
+        super(QuantizationTests, self).setUp()
+        self.quantize_registry = QuantizeRegistry()
 

-    def common_quantize(self, name):
-        dest = os.path.splitext(os.path.split(name)[-1])[0] + '.tflite'
-        quantize_model_save(name, dest)
+    def test_quantize_keras(self):
+        model = quantize_model(
+            keras.Sequential([
+                layers.Dense(3, activation='relu', input_shape=(5,)),
+                layers.Dense(3, activation='relu', input_shape=(3,)),
+            ]))
+        model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
+        model.summary()
+        x = np.array([[0, 1, 2, 3, 4]], dtype=np.float32)
+        y = model.predict(x)
+        print(y)
+        model_onnx = convert_keras(model)
+        print(model_onnx)
+
+    def test_quantize_tf(self):
+        inputs = tf.keras.layers.Input(shape=(5,))
+        inter = tf.keras.layers.Dense(3, activation='relu')(inputs)
+        outputs = tf.keras.layers.Dense(3, activation='relu')(inter)
+        model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
+        model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
+        model = quantize_model(model)
+        model.summary()
+        x = np.array([[0, 1, 2, 3, 4]], dtype=np.float32)
+        y = model(x)
+        print(y)
+        print(type(model))
 
+        # Freeze from the session backing the Keras model, not the model object.
+        sess = tf.keras.backend.get_session()
+        frozen_graph = freeze_session(
+            sess,
+            output_names=[out.op.name for out in model.outputs])
+
+        # Re-import the frozen graph and check it still evaluates.
+        with sess.graph.as_default():
+            tf.import_graph_def(frozen_graph, name='import')
+            output_tensor = sess.graph.get_tensor_by_name('import/' + model.outputs[0].name)
+            predictions = sess.run(output_tensor, {'import/' + model.inputs[0].name: x})
+        print(predictions)
+
+        model_onnx = convert_keras(model)
+        assert model_onnx is not None
+
+    def testQuantizesOutputsFromLayer(self):
+        # source: https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py
+        # TODO(pulkitb): Increase coverage by adding other output quantize layers
+        # such as AveragePooling etc.
 
-    def test_fake_quant_with_min_max_vars_gradient(self):
-        cwd = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
-        name = os.path.join(cwd, "gru", "frozen.pb")
-        self.common_quantize(name)
+        layer = layers.ReLU()
+        quantized_model = keras.Sequential([
+            QuantizeWrapper(
+                layers.ReLU(),
+                quantize_config=self.quantize_registry.get_quantize_config(layer))
+        ])
 
+        model = keras.Sequential([layers.ReLU()])
 
+        inputs = np.random.rand(1, 2, 1).astype(np.float32)
+        expected_output = tf.quantization.fake_quant_with_min_max_vars(
+            model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False)
+        got_keras = quantized_model.predict(inputs)
+        self.assertAllClose(expected_output, got_keras)
+
+        model_onnx = convert_keras(model)
+        quantized_model_onnx = convert_keras(quantized_model)
+        assert model_onnx is not None
+        assert quantized_model_onnx is not None
+        if 'op_type: "FakeQuantWithMinMaxVars"' in str(quantized_model_onnx):
+            raise AssertionError(
+                "FakeQuantWithMinMaxVars not replaced\n{}".format(quantized_model_onnx))
+        assert 'op_type: "QuantizeLinear"' in str(quantized_model_onnx)
+        assert 'op_type: "DequantizeLinear"' in str(quantized_model_onnx)
+        from onnxruntime import InferenceSession
+        sess = InferenceSession(quantized_model_onnx.SerializeToString())
+        names = [_.name for _ in sess.get_inputs()]
+        got = sess.run(None, {names[0]: inputs})
+        self.assertAllClose(expected_output, got[0])
+
 
 if __name__ == '__main__':
     unittest_main()
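
The final assertAllClose in testQuantizesOutputsFromLayer rests on FakeQuantWithMinMaxVars being numerically equivalent to a QuantizeLinear step followed by a DequantizeLinear step. A self-contained NumPy sketch of that round trip, under the same assumptions as above; fake_quant_reference is a hypothetical name, not code from this commit:

    import numpy as np

    def fake_quant_reference(x, scale, zero_point, qmin=0, qmax=255):
        # QuantizeLinear: q = clip(round(x / scale) + zero_point, qmin, qmax)
        q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
        # DequantizeLinear: y = (q - zero_point) * scale
        return ((q - zero_point) * scale).astype(np.float32)

    x = np.linspace(-8.0, 8.0, 9).astype(np.float32)
    scale, zero_point = 12.0 / 255.0, 128  # the (-6, 6), 8-bit range from the test
    print(fake_quant_reference(x, scale, zero_point))
    # Values inside [-6, 6] snap onto the 256-level grid; values outside
    # saturate near -6.02 and 5.98, matching fake_quant_with_min_max_vars.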
