Skip to content

Commit c320f50

Browse files
authored
Merge pull request #1132 from jurevreca12/fix-FuseQuantWithConstant
Fix problem with scale being a multidimensional array.
2 parents e778ed3 + 4eb0746 commit c320f50

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

hls4ml/model/optimizer/passes/quant_opt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def match(self, node):
167167
scale_unit_or_po2 = (scale == np.ones_like(scale)).all()
168168
if not scale_unit_or_po2 and _ALSO_MATCH_PO2:
169169
# This optimization only works if all scales are the same
170-
if np.all(scale[0] == scale):
171-
mantissa, _ = np.frexp(scale[0])
170+
if np.all(scale.item(0) == scale):
171+
mantissa, _ = np.frexp(scale.item(0))
172172
scale_unit_or_po2 = mantissa == 0.5
173173

174174
is_match = scale_unit_or_po2
@@ -187,7 +187,7 @@ def transform(self, model, node):
187187
integer = bitwidth
188188
scale = node.get_attr('scale')
189189
if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all():
190-
_, exp = np.frexp(scale[0]) # know that np.all(scale[0] == scale) must be true
190+
_, exp = np.frexp(scale.item(0)) # know that np.all(scale.item(0) == scale) must be true
191191
integer = bitwidth + exp - 1
192192

193193
precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode)

0 commit comments

Comments
 (0)