Skip to content

Commit c8e1857

Browse files
authored
Merge pull request #1142 from fastmachinelearning/qonnx_warnings
Qonnx warnings
2 parents ce7f1f1 + 915d2e1 commit c8e1857

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

hls4ml/model/optimizer/passes/batchnorm_opt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class FuseConsecutiveBatchNormalization(OptimizerPass):
166166
"""
167167

168168
def match(self, node):
169-
prev_node = node.get_input_node(node.inputs[0])
169+
prev_node = node.get_input_node()
170170
basic_match = (
171171
isinstance(node, BatchNormalization)
172172
and isinstance(prev_node, BatchNormalization)
@@ -194,7 +194,7 @@ def match(self, node):
194194
return False
195195

196196
def transform(self, model, node):
197-
prev_node = node.get_input_node(node.inputs[0])
197+
prev_node = node.get_input_node()
198198

199199
prev_map = prev_node.get_output_use_map()
200200
if len(prev_map[prev_node.outputs[0]]) > 1:

hls4ml/model/optimizer/passes/bn_fuse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class FuseBatchNormalization(OptimizerPass):
1818
"""
1919

2020
def match(self, node):
21-
prev_node = node.get_input_node(node.inputs[0])
21+
prev_node = node.get_input_node()
2222
basic_match = (
2323
isinstance(node, BatchNormalization)
2424
and isinstance(prev_node, (Dense, Conv1D, Conv2D))

hls4ml/model/optimizer/passes/move_scales.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
66
'''
77

8+
import warnings
9+
810
import numpy as np
911

1012
from hls4ml.model.layers import ApplyAlpha, Constant, Conv, MatMul, Merge
@@ -85,6 +87,9 @@ def transform(self, model, node):
8587
can_propagate = False
8688

8789
if not can_propagate:
90+
warnings.warn(
91+
'Failed to propagate quantization scales down MatMul node; model probably not suppored.', stacklevel=1
92+
)
8893
return False
8994

9095
model.remove_node(apply_alpha)
@@ -124,6 +129,9 @@ def transform(self, model, node):
124129
try:
125130
bias = bias0 + bias1
126131
except ValueError:
132+
warnings.warn(
133+
'Failed to propagate quantization scales down Add node; model probably not suppored.', stacklevel=1
134+
)
127135
return False
128136

129137
model.remove_node(in0)
@@ -169,6 +177,7 @@ def transform(self, model, node):
169177
model.insert_node(new_node)
170178
return True
171179
else:
180+
warnings.warn('Failed to propagate quantization bias down Add node; model probably not suppored.', stacklevel=1)
172181
return False
173182

174183

@@ -243,6 +252,9 @@ def transform(self, model, node):
243252
except ValueError:
244253
can_propagate = False
245254
if not can_propagate:
255+
warnings.warn(
256+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
257+
)
246258
return False
247259

248260
# to remove warning, since these get set again
@@ -287,6 +299,9 @@ def transform(self, model, node):
287299
except ValueError:
288300
can_propagate = False
289301
if not can_propagate:
302+
warnings.warn(
303+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
304+
)
290305
return False
291306

292307
# to remove warning, since these get set again
@@ -308,6 +323,9 @@ def transform(self, model, node):
308323
can_propagate = False
309324

310325
if not can_propagate:
326+
warnings.warn(
327+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
328+
)
311329
return False
312330

313331
# to remove warning, since these get set again
@@ -367,6 +385,9 @@ def transform(self, model, node):
367385
except ValueError:
368386
can_propagate = False
369387
if not can_propagate:
388+
warnings.warn(
389+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
390+
)
370391
return False
371392

372393
# to remove warning, since these get set again
@@ -388,6 +409,9 @@ def transform(self, model, node):
388409
except ValueError:
389410
can_propagate = False
390411
if not can_propagate:
412+
warnings.warn(
413+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
414+
)
391415
return False
392416

393417
# to remove warning, since these get set again
@@ -412,6 +436,9 @@ def transform(self, model, node):
412436
except ValueError:
413437
can_propagate = False
414438
if not can_propagate:
439+
warnings.warn(
440+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
441+
)
415442
return False
416443

417444
# to remove warning, since these get set again
@@ -445,6 +472,9 @@ def transform(self, model, node):
445472
except ValueError:
446473
can_propagate = False
447474
if not can_propagate:
475+
warnings.warn(
476+
'Failed to propagate quantization scales down Conv node; model probably not suppored.', stacklevel=1
477+
)
448478
return False
449479

450480
# to remove warning, since these get set again

0 commit comments

Comments
 (0)