Commit aaaa2fc

loosen batchnorm merging restrictions, fix ternary handling
1 parent c9693da commit aaaa2fc

4 files changed: +29 -21 lines changed

hls4ml/model/optimizer/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -61,6 +61,8 @@
         'merge_linear_activation',
         'fuse_batch_normalization',
         'eliminate_linear_activation',
+        'qkeras_factorize_alpha',
+        'extract_ternary_threshold',
         # The ones above here need to be before infer_precision_types
         'infer_precision_types',
         'channels_last_converter',
@@ -70,8 +72,6 @@
         'fuse_bias_add',
         'expand_layer_group',
         'output_rounding_saturation_mode',
-        'qkeras_factorize_alpha',
-        'extract_ternary_threshold',
     ],
     requires=['parse_qonnx'],
 )  # TODO Maybe not all QKeras optmizers belong here?
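In the convert flow, 'qkeras_factorize_alpha' and 'extract_ternary_threshold' now run before 'infer_precision_types', so the ApplyAlpha and ternary-threshold nodes they create already exist when precisions are inferred. A minimal sketch of the resulting ordering (abbreviated; the real list in hls4ml/model/optimizer/__init__.py is much longer):

# Sketch only: abbreviated view of the 'convert' flow ordering after this commit.
convert_passes = [
    'fuse_batch_normalization',
    'eliminate_linear_activation',
    'qkeras_factorize_alpha',      # moved up: creates ApplyAlpha nodes
    'extract_ternary_threshold',   # moved up: creates ternary-threshold nodes
    # The ones above here need to be before infer_precision_types
    'infer_precision_types',       # can now assign types to the new nodes
]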

hls4ml/model/optimizer/passes/batchnorm_opt.py

Lines changed: 15 additions & 17 deletions

@@ -170,12 +170,12 @@ def match(self, node):
         s1 = node.weights['scale'].data_unquantized
         b1 = node.weights['bias'].data_unquantized
         scale_compatible = (
-            (prev_node.get_attr('scale_quantizer') is None and node.get_attr('scale_quantizer') is None)
+            (prev_node.get_attr('scale_quantizer') is None or node.get_attr('scale_quantizer') is None)
             or (s0 == np.ones_like(s0)).all()
             or (s1 == np.ones_like(s1)).all()
         )
         bias_compatible = (
-            (prev_node.get_attr('bias_quantizer') is None and node.get_attr('bias_quantizer') is None)
+            (prev_node.get_attr('bias_quantizer') is None or node.get_attr('bias_quantizer') is None)
             or (b0 == np.zeros_like(b0)).all()
             or (b1 == np.zeros_like(b1)).all()
         )
@@ -195,26 +195,24 @@ def transform(self, model, node):
         # if len(node_map[node.outputs[0]]) > 1:
         #     return False

-        # only merge if the types are integer or fixed
-        if (
-            not isinstance(prev_node.weights['scale'].type.precision, (IntegerPrecisionType, FixedPrecisionType))
-            or not isinstance(prev_node.weights['bias'].type.precision, (IntegerPrecisionType, FixedPrecisionType))
-            or not isinstance(node.weights['scale'].type.precision, (IntegerPrecisionType, FixedPrecisionType))
-            or not isinstance(node.weights['bias'].type.precision, (IntegerPrecisionType, FixedPrecisionType))
-        ):
-            return False
-
         s0 = prev_node.weights['scale'].data_unquantized
         b0 = prev_node.weights['bias'].data_unquantized
         s1 = node.weights['scale'].data_unquantized
         b1 = node.weights['bias'].data_unquantized

-        s_quantizer = (
-            node.get_attr('scale_quantizer') if (s0 == np.ones_like(s0)).all() else prev_node.get_attr('scale_quantizer')
-        )
-        b_quantizer = (
-            node.get_attr('bias_quantizer') if (b0 == np.zeros_like(b0)).all() else prev_node.get_attr('bias_quantizer')
-        )
+        if (s0 == np.ones_like(s0)).all():
+            s_quantizer = node.get_attr('scale_quantizer')
+        elif (s1 == np.ones_like(s1)).all():
+            s_quantizer = prev_node.get_attr('scale_quantizer')
+        else:
+            s_quantizer = None
+
+        if (b0 == np.ones_like(b0)).all():
+            b_quantizer = node.get_attr('bias_quantizer')
+        elif (b1 == np.ones_like(b1)).all():
+            b_quantizer = prev_node.get_attr('bias_quantizer')
+        else:
+            b_quantizer = None

         node.set_attr('scale_quantizer', s_quantizer)
         node.set_attr('bias_quantizer', b_quantizer)
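The match() change swaps 'and' for 'or': two consecutive BatchNormalization nodes can now be fused when at most one of them carries a scale (or bias) quantizer, or when one side's scale is all ones (bias all zeros). transform() additionally drops the requirement that all four weight types be integer or fixed, and falls back to no quantizer when neither side's parameters are trivial. A minimal standalone sketch of the loosened scale check (not hls4ml code; names are illustrative):

import numpy as np

# Illustrative sketch of the loosened compatibility rule for the scale factors
# of two back-to-back BatchNormalization nodes: q0/q1 are their quantizers,
# s0/s1 their scale arrays. Fusion is allowed if at most one scale is
# quantized, or if either scale is trivially all ones.
def scales_compatible(q0, q1, s0, s1):
    return (
        (q0 is None or q1 is None)
        or (s0 == np.ones_like(s0)).all()
        or (s1 == np.ones_like(s1)).all()
    )

# Both quantized and both non-trivial -> not fusable; otherwise fusable.
print(scales_compatible('q8', 'q8', np.array([2.0]), np.array([3.0])))  # False
print(scales_compatible(None, 'q8', np.array([2.0]), np.array([3.0])))  # True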

hls4ml/model/optimizer/passes/qkeras.py

Lines changed: 9 additions & 1 deletion

@@ -163,8 +163,16 @@ def transform(self, model, node):
         else:
             n_in = node.get_attr('n_out')

+        # the name of the new ApplyAlpha node
+        alpha_name = node.get_attr('name') + '_alpha'
+
+        # make the precision auto
+        alpha_precision = {'Precision': 'auto'}
+        model.config.set_name_config(alpha_name, alpha_precision)
+        model.config.parse_name_config(alpha_name, alpha_precision)
+
         attrs = {
-            'name': node.get_attr('name') + '_alpha',
+            'name': alpha_name,
             'class_name': 'Alpha',
             'inputs': node.outputs,
             'n_in': n_in,
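The factored-out ApplyAlpha node is now registered with an explicit per-layer configuration of 'Precision': 'auto' before it is created, so infer_precision_types (which now runs after this pass, per the flow change above) assigns it a concrete type. Roughly, this amounts to the kind of per-layer override a user could write in an hls4ml config dict; a hedged sketch with an illustrative layer name:

# Hedged sketch, not taken from the commit: a per-layer precision override in
# an hls4ml-style config dict. The 'dense_0_alpha' name is hypothetical; the
# commit builds the real name as node.get_attr('name') + '_alpha'.
config = {
    'Model': {'Precision': 'ap_fixed<16,6>', 'ReuseFactor': 1},
    'LayerName': {
        'dense_0_alpha': {'Precision': 'auto'},  # let hls4ml infer the type
    },
}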

test/pytest/test_qkeras.py

Lines changed: 3 additions & 1 deletion

@@ -356,8 +356,10 @@ def test_relu_negative_slope(randX_1000_1, quantizer, backend, io_type):
     ],
 )
 def test_qactivation_kwarg(randX_100_10, activation_quantizer, weight_quantizer):
-    if activation_quantizer in ['binary', 'ternary']:
+    if activation_quantizer in ['binary']:
         name = 'bnbt_qdense_alpha'
+    elif activation_quantizer in ['ternary']:
+        name = 'bnbt_qdense_ternary_scale'
     else:
         name = f'qdense_{eval(activation_quantizer).__class__.__name__}'
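The ternary case now compares against its own baseline ('bnbt_qdense_ternary_scale') instead of sharing the binary one, matching the separate scale handling above. For context, a ternary quantizer maps values onto {-1, 0, +1} around a threshold, which is the threshold that extract_ternary_threshold pulls out into its own node; a toy illustration (the threshold value here is arbitrary, not the qkeras default):

import numpy as np

# Toy illustration only: ternary quantization maps values to {-1, 0, +1}
# around a threshold.
def ternary(x, threshold=0.5):
    return np.where(x > threshold, 1, np.where(x < -threshold, -1, 0))

print(ternary(np.array([-0.9, -0.2, 0.1, 0.7])))  # [-1  0  0  1]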
