@@ -244,10 +244,12 @@ def __init__(self):
244244}};\n """
245245
246246sepconv1d_function_template = (
247- 'nnet::separable_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
247+ 'nnet::separable_conv_1d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>('
248+ '{input}, {output}, {d}, {p}, {z}, {b});'
248249)
249250sepconv2d_function_template = (
250- 'nnet::separable_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
251+ 'nnet::separable_conv_2d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>('
252+ '{input}, {output}, {d}, {p}, {z}, {b});'
251253)
252254
253255sepconv1d_include_list = ['nnet_utils/nnet_conv1d.h' , 'nnet_utils/nnet_sepconv1d_stream.h' ]
@@ -273,14 +275,17 @@ def format(self, node):
273275
274276 # Depthwise config
275277 params = self ._default_config_params (node )
278+ # Override bias and bias_t since these are zeros in depthwise step of SepConv1D
279+ params ['bias' ] = params ['zero_bias' ]
280+ params ['bias_t' ] = params ['zero_bias_t' ]
276281 params ['n_filt' ] = params ['n_chan' ] # In depthwise step n_chan == n_filt
277282 params ['dilation' ] = node .get_attr ('dilation' , 1 )
278283 params ['nzeros' ] = node .get_weights ('depthwise' ).nzeros
279284 params ['index' ] = str (node .index ) + '_depthwise'
280285 params ['weight_t' ] = node .get_weights ('depthwise' ).type
281286 params ['fill_fn' ] = 'FillConv1DBuffer'
282287
283- if node .get_attr (" unscaled" ):
288+ if node .get_attr (' unscaled' ):
284289 params ['scale_index_type' ] = 'scale_index_unscaled'
285290 else :
286291 params ['scale_index_type' ] = 'scale_index_regular'
@@ -301,14 +306,11 @@ def format(self, node):
301306 depthwise_mult_config = self .depthwise_mult_template .format (** mult_params )
302307
303308 # Pointwise config
304- params = self ._default_config_params ()
305- input_shape = self .get_input_variable ().shape
306- if self .get_attr ('data_format' ) == 'channels_last' :
307- params ['in_width' ] = '*' .join ([str (k ) for k in input_shape [:- 1 ]])
308- params ['n_chan' ] = input_shape [- 1 ]
309+ params = self ._default_config_params (node )
310+ if node .get_attr ('data_format' ) == 'channels_last' :
311+ params ['in_width' ] = node .get_output_variable ().shape [0 ]
309312 else :
310- params ['in_width' ] = '*' .join ([str (k ) for k in input_shape [1 :]])
311- params ['n_chan' ] = input_shape [0 ]
313+ params ['in_width' ] = node .get_output_variable ().shape [1 ]
312314
313315 params ['filt_width' ] = 1
314316 params ['stride_width' ] = 1
@@ -320,7 +322,7 @@ def format(self, node):
320322 params ['instructions' ] = '0'
321323 params ['fill_fn' ] = 'FillConv1DBuffer'
322324
323- if node .get_attr (" unscaled" ):
325+ if node .get_attr (' unscaled' ):
324326 params ['scale_index_type' ] = 'scale_index_unscaled'
325327 else :
326328 params ['scale_index_type' ] = 'scale_index_regular'
@@ -360,6 +362,7 @@ def __init__(self):
360362
361363 def format (self , node ):
362364 params = self ._default_function_params (node )
365+ params ['dw_output_t' ] = node .get_attr ('dw_output_t' ).name
363366 params ['data_format' ] = 'cf' if node .get_attr ('data_format' ) == 'channels_first' else 'cl'
364367 params ['d' ] = node .get_weights ('depthwise' ).name
365368 params ['p' ] = node .get_weights ('pointwise' ).name
@@ -398,12 +401,12 @@ def format(self, node):
398401 params ['weight_t' ] = node .get_weights ('depthwise' ).type
399402 params ['fill_fn' ] = 'FillConv2DBuffer'
400403
401- if node .get_attr (" unscaled_h" ):
404+ if node .get_attr (' unscaled_h' ):
402405 params ['scale_index_height_type' ] = 'scale_index_unscaled'
403406 else :
404407 params ['scale_index_height_type' ] = 'scale_index_regular'
405408
406- if node .get_attr (" unscaled_w" ):
409+ if node .get_attr (' unscaled_w' ):
407410 params ['scale_index_width_type' ] = 'scale_index_unscaled'
408411 else :
409412 params ['scale_index_width_type' ] = 'scale_index_regular'
@@ -443,12 +446,12 @@ def format(self, node):
443446 params ['instructions' ] = '0'
444447 params ['fill_fn' ] = 'FillConv2DBuffer'
445448
446- if node .get_attr (" unscaled_h" ):
449+ if node .get_attr (' unscaled_h' ):
447450 params ['scale_index_height_type' ] = 'scale_index_unscaled'
448451 else :
449452 params ['scale_index_height_type' ] = 'scale_index_regular'
450453
451- if node .get_attr (" unscaled_w" ):
454+ if node .get_attr (' unscaled_w' ):
452455 params ['scale_index_width_type' ] = 'scale_index_unscaled'
453456 else :
454457 params ['scale_index_width_type' ] = 'scale_index_regular'
@@ -487,6 +490,7 @@ def __init__(self):
487490
488491 def format (self , node ):
489492 params = self ._default_function_params (node )
493+ params ['dw_output_t' ] = node .get_attr ('dw_output_t' ).name
490494 params ['data_format' ] = 'cf' if node .get_attr ('data_format' ) == 'channels_first' else 'cl'
491495 params ['d' ] = node .get_weights ('depthwise' ).name
492496 params ['p' ] = node .get_weights ('pointwise' ).name
0 commit comments