Skip to content

Commit e64b5d0

Browse files
committed
Create test_quantization.py
1 parent d91315c commit e64b5d0

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/test_quantization.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""Unit Tests for custom rnns."""
5+
6+
from __future__ import absolute_import
7+
from __future__ import division
8+
from __future__ import print_function
9+
import os
10+
import numpy as np
11+
import tensorflow as tf
12+
13+
from tensorflow.python.ops import init_ops, random_ops, init_ops
14+
from tensorflow.python.ops.array_ops import FakeQuantWithMinMaxVars
15+
from backend_test_base import Tf2OnnxBackendTestBase
16+
from common import unittest_main, check_gru_count, check_opset_after_tf_version, skip_tf2
17+
from tf2onnx.tf_loader import is_tf2
18+
from tensorflow_model_optimization.python.core.quantization.keras import quantize
19+
20+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
21+
# pylint: disable=abstract-method,arguments-differ
22+
23+
if is_tf2():
24+
fake_quant_with_min_max_vars_gradient = tf.compat.v1.quantization.fake_quant_with_min_max_vars_gradient
25+
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
26+
else:
27+
fake_quant_with_min_max_vars_gradient = tf.quantization.fake_quant_with_min_max_vars_gradient
28+
dynamic_rnn = tf.nn.dynamic_rnn
29+
30+
31+
def quantize_model_save(keras_file, tflite_file):
32+
with quantize.quantize_scope():
33+
model = tf.keras.models.load_model(keras_file)
34+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
35+
36+
converter.representative_dataset = calibration_gen
37+
converter._experimental_new_quantizer = True # pylint: disable=protected-access
38+
converter.target_spec.supported_ops = [
39+
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
40+
] # to enable post-training quantization with the representative dataset
41+
42+
tflite_model = converter.convert()
43+
tflite_file = 'quantized_mnist.tflite'
44+
open(tflite_file, 'wb').write(tflite_model)
45+
46+
47+
class QuantizationTests(Tf2OnnxBackendTestBase):
48+
49+
def common_quantize(self, name):
50+
dest = os.path.splitext(os.path.split(name)[-1])[0] + '.tflite'
51+
quantize_model_save(name, dest)
52+
53+
54+
def test_fake_quant_with_min_max_vars_gradient(self):
55+
cwd = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
56+
name = os.path.join(cwd, "gru", "frozen.pb")
57+
self.common_quantize(name)
58+
59+
60+
if __name__ == '__main__':
61+
unittest_main()

0 commit comments

Comments
 (0)