@@ -275,14 +275,17 @@ def format(self, node):
275275
276276 # Depthwise config
277277 params = self ._default_config_params (node )
278+ # Override bias and bias_t since these are zeros in depthwise step of SepConv1D
279+ params ['bias' ] = params ['zero_bias' ]
280+ params ['bias_t' ] = params ['zero_bias_t' ]
278281 params ['n_filt' ] = params ['n_chan' ] # In depthwise step n_chan == n_filt
279282 params ['dilation' ] = node .get_attr ('dilation' , 1 )
280283 params ['nzeros' ] = node .get_weights ('depthwise' ).nzeros
281284 params ['index' ] = str (node .index ) + '_depthwise'
282285 params ['weight_t' ] = node .get_weights ('depthwise' ).type
283286 params ['fill_fn' ] = 'FillConv1DBuffer'
284287
285- if node .get_attr (" unscaled" ):
288+ if node .get_attr (' unscaled' ):
286289 params ['scale_index_type' ] = 'scale_index_unscaled'
287290 else :
288291 params ['scale_index_type' ] = 'scale_index_regular'
@@ -303,14 +306,11 @@ def format(self, node):
303306 depthwise_mult_config = self .depthwise_mult_template .format (** mult_params )
304307
305308 # Pointwise config
306- params = self ._default_config_params ()
307- input_shape = self .get_input_variable ().shape
308- if self .get_attr ('data_format' ) == 'channels_last' :
309- params ['in_width' ] = '*' .join ([str (k ) for k in input_shape [:- 1 ]])
310- params ['n_chan' ] = input_shape [- 1 ]
309+ params = self ._default_config_params (node )
310+ if node .get_attr ('data_format' ) == 'channels_last' :
311+ params ['in_width' ] = node .get_output_variable ().shape [0 ]
311312 else :
312- params ['in_width' ] = '*' .join ([str (k ) for k in input_shape [1 :]])
313- params ['n_chan' ] = input_shape [0 ]
313+ params ['in_width' ] = node .get_output_variable ().shape [1 ]
314314
315315 params ['filt_width' ] = 1
316316 params ['stride_width' ] = 1
@@ -322,7 +322,7 @@ def format(self, node):
322322 params ['instructions' ] = '0'
323323 params ['fill_fn' ] = 'FillConv1DBuffer'
324324
325- if node .get_attr (" unscaled" ):
325+ if node .get_attr (' unscaled' ):
326326 params ['scale_index_type' ] = 'scale_index_unscaled'
327327 else :
328328 params ['scale_index_type' ] = 'scale_index_regular'
@@ -401,12 +401,12 @@ def format(self, node):
401401 params ['weight_t' ] = node .get_weights ('depthwise' ).type
402402 params ['fill_fn' ] = 'FillConv2DBuffer'
403403
404- if node .get_attr (" unscaled_h" ):
404+ if node .get_attr (' unscaled_h' ):
405405 params ['scale_index_height_type' ] = 'scale_index_unscaled'
406406 else :
407407 params ['scale_index_height_type' ] = 'scale_index_regular'
408408
409- if node .get_attr (" unscaled_w" ):
409+ if node .get_attr (' unscaled_w' ):
410410 params ['scale_index_width_type' ] = 'scale_index_unscaled'
411411 else :
412412 params ['scale_index_width_type' ] = 'scale_index_regular'
@@ -446,12 +446,12 @@ def format(self, node):
446446 params ['instructions' ] = '0'
447447 params ['fill_fn' ] = 'FillConv2DBuffer'
448448
449- if node .get_attr (" unscaled_h" ):
449+ if node .get_attr (' unscaled_h' ):
450450 params ['scale_index_height_type' ] = 'scale_index_unscaled'
451451 else :
452452 params ['scale_index_height_type' ] = 'scale_index_regular'
453453
454- if node .get_attr (" unscaled_w" ):
454+ if node .get_attr (' unscaled_w' ):
455455 params ['scale_index_width_type' ] = 'scale_index_unscaled'
456456 else :
457457 params ['scale_index_width_type' ] = 'scale_index_regular'
0 commit comments