Skip to content

Commit 697e825

Browse files
committed
SepConv1D fixes
1 parent 1f9668d commit 697e825

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,17 @@ def format(self, node):
275275

276276
# Depthwise config
277277
params = self._default_config_params(node)
278+
# Override bias and bias_t since these are zeros in depthwise step of SepConv1D
279+
params['bias'] = params['zero_bias']
280+
params['bias_t'] = params['zero_bias_t']
278281
params['n_filt'] = params['n_chan'] # In depthwise step n_chan == n_filt
279282
params['dilation'] = node.get_attr('dilation', 1)
280283
params['nzeros'] = node.get_weights('depthwise').nzeros
281284
params['index'] = str(node.index) + '_depthwise'
282285
params['weight_t'] = node.get_weights('depthwise').type
283286
params['fill_fn'] = 'FillConv1DBuffer'
284287

285-
if node.get_attr("unscaled"):
288+
if node.get_attr('unscaled'):
286289
params['scale_index_type'] = 'scale_index_unscaled'
287290
else:
288291
params['scale_index_type'] = 'scale_index_regular'
@@ -303,14 +306,11 @@ def format(self, node):
303306
depthwise_mult_config = self.depthwise_mult_template.format(**mult_params)
304307

305308
# Pointwise config
306-
params = self._default_config_params()
307-
input_shape = self.get_input_variable().shape
308-
if self.get_attr('data_format') == 'channels_last':
309-
params['in_width'] = '*'.join([str(k) for k in input_shape[:-1]])
310-
params['n_chan'] = input_shape[-1]
309+
params = self._default_config_params(node)
310+
if node.get_attr('data_format') == 'channels_last':
311+
params['in_width'] = node.get_output_variable().shape[0]
311312
else:
312-
params['in_width'] = '*'.join([str(k) for k in input_shape[1:]])
313-
params['n_chan'] = input_shape[0]
313+
params['in_width'] = node.get_output_variable().shape[1]
314314

315315
params['filt_width'] = 1
316316
params['stride_width'] = 1
@@ -322,7 +322,7 @@ def format(self, node):
322322
params['instructions'] = '0'
323323
params['fill_fn'] = 'FillConv1DBuffer'
324324

325-
if node.get_attr("unscaled"):
325+
if node.get_attr('unscaled'):
326326
params['scale_index_type'] = 'scale_index_unscaled'
327327
else:
328328
params['scale_index_type'] = 'scale_index_regular'
@@ -401,12 +401,12 @@ def format(self, node):
401401
params['weight_t'] = node.get_weights('depthwise').type
402402
params['fill_fn'] = 'FillConv2DBuffer'
403403

404-
if node.get_attr("unscaled_h"):
404+
if node.get_attr('unscaled_h'):
405405
params['scale_index_height_type'] = 'scale_index_unscaled'
406406
else:
407407
params['scale_index_height_type'] = 'scale_index_regular'
408408

409-
if node.get_attr("unscaled_w"):
409+
if node.get_attr('unscaled_w'):
410410
params['scale_index_width_type'] = 'scale_index_unscaled'
411411
else:
412412
params['scale_index_width_type'] = 'scale_index_regular'
@@ -446,12 +446,12 @@ def format(self, node):
446446
params['instructions'] = '0'
447447
params['fill_fn'] = 'FillConv2DBuffer'
448448

449-
if node.get_attr("unscaled_h"):
449+
if node.get_attr('unscaled_h'):
450450
params['scale_index_height_type'] = 'scale_index_unscaled'
451451
else:
452452
params['scale_index_height_type'] = 'scale_index_regular'
453453

454-
if node.get_attr("unscaled_w"):
454+
if node.get_attr('unscaled_w'):
455455
params['scale_index_width_type'] = 'scale_index_unscaled'
456456
else:
457457
params['scale_index_width_type'] = 'scale_index_regular'

hls4ml/model/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ class Layer:
5656
ConfigurableAttribute('trace', default=False),
5757
TypeAttribute('result'),
5858
]
59-
""""""
6059

6160
@classproperty
6261
def expected_attributes(cls):
@@ -1331,8 +1330,10 @@ def initialize(self):
13311330
'QConv2D': Conv2D,
13321331
'QConv2DBatchnorm': Conv2DBatchnorm,
13331332
'SeparableConv1D': SeparableConv1D,
1333+
'QSeparableConv1D': SeparableConv1D,
13341334
'DepthwiseConv1D': DepthwiseConv1D,
13351335
'SeparableConv2D': SeparableConv2D,
1336+
'QSeparableConv2D': SeparableConv2D,
13361337
'DepthwiseConv2D': DepthwiseConv2D,
13371338
'QDepthwiseConv2D': DepthwiseConv2D,
13381339
'BatchNormalization': BatchNormalization,

0 commit comments

Comments
 (0)