Skip to content

Commit 78ea90f

Browse files
lishanok authored and copybara-github committed
No public description
PiperOrigin-RevId: 770873511 Change-Id: I1219eeb60f5b4e1dee3c5f4ea82bd8a3c8d92a01
1 parent 378143c commit 78ea90f

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

qkeras/quantizers.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2260,6 +2260,11 @@ def get_config(self):
22602260
return config
22612261

22622262

2263+
@tf.function(jit_compile=True)
def fast_relu_quantize(p, m_i, factor):
  """Rounds, rescales and clips `p`, then multiplies the result by `m_i`.

  Defined at module level and wrapped in `tf.function(jit_compile=True)`.

  Args:
    p: Value to quantize; it is rounded with `tf.round` before rescaling.
    m_i: Multiplier applied to the clipped value.
    factor: Scale applied to the rounded value; it also sets the upper
      clipping bound `1.0 - factor`.

  Returns:
    `m_i * clip(round(p) * factor, 0.0, 1.0 - factor)`.
  """
  rescaled = tf.round(p) * factor
  upper_bound = 1.0 - factor
  clipped = tf.clip_by_value(rescaled, 0.0, upper_bound)
  return m_i * clipped
2266+
2267+
22632268
@quantizer_registry.register_quantizer
22642269
class quantized_relu(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name
22652270
"""Computes a quantized relu to a number of bits.
@@ -2359,17 +2364,13 @@ def __str__(self):
23592364
flags.append(str(int(self.use_stochastic_rounding)))
23602365
return "quantized_relu(" + ",".join(flags) + ")"
23612366

2362-
@tf.function(jit_compile=True)
2363-
def fast_quantize(p, m_i, factor):
2364-
return m_i * tf.clip_by_value(tf.round(p) * factor, 0.0, 1.0 - factor)
2365-
23662367
def __call__(self, x):
23672368
if self.enable_fast_inference:
23682369
# This is the fast inference version of the quantizer.
23692370
m_i = 1 << self.integer
23702371
p = x * (2 ** (self.bits - self.integer))
23712372
factor = 2 ** -self.bits
2372-
return self.fast_quantize(p, m_i, factor)
2373+
return fast_relu_quantize(p, m_i, factor)
23732374

23742375
if not self.built:
23752376
self.build(var_name=self.var_name, use_variables=self.use_variables)

tests/qactivation_test.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -813,5 +813,12 @@ def test_quantized_hswish(bits, integer, symmetric, relu_shift,
813813
assert_allclose(result, expected_values, rtol=1e-05)
814814

815815

816+
def test_quantized_relu_fast_inference():
  """Checks that the fast-inference path matches the default path exactly.

  Two `quantized_relu` instances are built with the same positional
  arguments and differ only in `enable_fast_inference`. Their outputs on the
  same input array must be element-wise equal.
  """
  reference_quantizer, fast_quantizer = (
      quantized_relu(10, 2, enable_fast_inference=fast)
      for fast in (False, True)
  )
  # One negative value and several positive values of increasing magnitude.
  inputs = np.array([-2.1, 0.73, 2.36, 4.98])
  reference_output = reference_quantizer(inputs).numpy()
  fast_output = fast_quantizer(inputs).numpy()
  np.testing.assert_array_equal(reference_output, fast_output)
821+
822+
816823
if __name__ == '__main__':
817824
pytest.main([__file__])

0 commit comments

Comments
 (0)