Changes from all commits (122 commits)
c3470eb
change nnet::array<apfix, N> to apfix
ChiRuiChen Aug 6, 2022
2e465c1
Add several conditions for resuse factors
ChiRuiChen Aug 6, 2022
f9dac81
Set stream vairables to depth=1
ChiRuiChen Aug 6, 2022
24beaaf
Add functions for single stream
ChiRuiChen Aug 6, 2022
004401f
Use functions for single stream
ChiRuiChen Aug 6, 2022
8b15114
add ss for linear, relu, sigmoid
ChiRuiChen Aug 6, 2022
41958fc
add ss for softmax
ChiRuiChen Aug 6, 2022
491fe56
add ss for leaky_relu
ChiRuiChen Aug 6, 2022
abbe598
add ss for batchnorm
ChiRuiChen Aug 6, 2022
3271abf
Update nnet_batchnorm_stream.h
ChiRuiChen Aug 10, 2022
2de6090
Update nnet_batchnorm_stream.h
ChiRuiChen Aug 10, 2022
5b16b5c
Conv2d Single Stream
ChiRuiChen Aug 10, 2022
4137c1c
Encoded SS
ChiRuiChen Aug 10, 2022
252b3c5
Update nnet_dense_resource.h
ChiRuiChen Aug 10, 2022
e62a9c6
Add Dense_ss with rf
ChiRuiChen Aug 11, 2022
37b2889
Upsampling2D ss
ChiRuiChen Aug 11, 2022
9644812
Single Stream for ZeroPadding2D
ChiRuiChen Aug 11, 2022
63c26a3
MaxPooling SS
ChiRuiChen Aug 11, 2022
f4741fe
add ss profix for conv2d
ChiRuiChen Aug 11, 2022
81966ee
add ss profix for dense and activations
ChiRuiChen Aug 11, 2022
a7ad6ae
ss profix for pooling2d
ChiRuiChen Aug 11, 2022
89dc62a
ss profix for zeropadding2d upsampling2d
ChiRuiChen Aug 11, 2022
355aed0
add rf constraint is dense_ss
ChiRuiChen Aug 11, 2022
ab0b948
input_precision typo
ChiRuiChen Aug 11, 2022
9e6255b
copy_data_ss type
ChiRuiChen Aug 11, 2022
111b243
convert_data_ss typo
ChiRuiChen Aug 11, 2022
7f71d6a
pooling2d_cl_ss rename
ChiRuiChen Aug 11, 2022
e599e05
conv_2d_buffer_cl_ss rename
ChiRuiChen Aug 11, 2022
ed851df
compute_output_encoded_ss typo
ChiRuiChen Aug 11, 2022
ca00c9c
use compute_scaled_indices_2d_ss
ChiRuiChen Aug 11, 2022
9c9f239
don't print out dense type
ChiRuiChen Aug 11, 2022
ff1f796
set rf to n_out in Dense_ss, set rf to n_in in other ss layers
ChiRuiChen Aug 11, 2022
bd41078
typo
ChiRuiChen Aug 11, 2022
92e5d8a
use _rf == n_in
ChiRuiChen Aug 11, 2022
c084db3
typo
ChiRuiChen Aug 11, 2022
edb5476
fix dense_ss rf condition
ChiRuiChen Aug 11, 2022
ac82625
add convert_data_ss for my_bridge
ChiRuiChen Aug 11, 2022
759a143
change loops in axi.cpp to ss
ChiRuiChen Aug 11, 2022
ae9e13e
Merge branch 'fastmachinelearning:main' into Single-stream
ChiRuiChen Aug 15, 2022
d9d9dd6
copy_data_axi no need profix ss
ChiRuiChen Aug 15, 2022
ddc20dd
add print_result ss
ChiRuiChen Aug 15, 2022
e0ddc61
dummy axi_lite_driver.py
ChiRuiChen Aug 15, 2022
4feb0ab
dummy axi_lite_driver.py
ChiRuiChen Aug 15, 2022
85ef02f
dummy axi_lite_driver.py
ChiRuiChen Aug 15, 2022
384fa8b
add deepcalo_layers into supported_layers
ChiRuiChen Aug 19, 2022
c46efda
remove out_hieght, out_width in GAP
ChiRuiChen Aug 30, 2022
92cdb3e
Add Global Avg Pool SS
ChiRuiChen Aug 31, 2022
fd8ea5f
fix normalize_ss product template
ChiRuiChen Sep 3, 2022
5d2aad1
typo
ChiRuiChen Sep 3, 2022
69ac048
add depthwise_conv_2d_ss
ChiRuiChen Sep 16, 2022
66af216
add DepthwiseConv2D support
ChiRuiChen Sep 16, 2022
88e6630
add profix ss for depthwise_conv_2d
ChiRuiChen Sep 16, 2022
4c90431
add prelu ss
ChiRuiChen Sep 16, 2022
6b807f1
add ss for pointwise_conv2d
ChiRuiChen Sep 18, 2022
b9aa92e
add ss profix for seperable conv2d
ChiRuiChen Sep 18, 2022
03b9686
add ss profix for pointwise_conv2d
ChiRuiChen Sep 18, 2022
795a04b
Update nnet_sepconv2d_stream.h
ChiRuiChen Sep 18, 2022
057b9ee
use model_default_t for depthwise bias in SeparableConv2D
ChiRuiChen Sep 18, 2022
33412a3
add ss for SeperableConv2d
ChiRuiChen Sep 18, 2022
8c5af48
fix typo in prelu_ss
ChiRuiChen Sep 18, 2022
b48c0ae
fix typo separable_conv_2d_cl_ss
ChiRuiChen Sep 18, 2022
17b7484
add clone_stream_ss
ChiRuiChen Sep 18, 2022
af052fa
add ss for add_ss
ChiRuiChen Sep 18, 2022
f5704b5
add ss profix for merge
ChiRuiChen Sep 18, 2022
ef89571
add ss profix for clone_stream
ChiRuiChen Sep 18, 2022
e5ccf18
fix DepthwiseConv2D weights
ChiRuiChen Sep 19, 2022
cc2b495
add DepthwiseConv2D
ChiRuiChen Sep 19, 2022
79f7114
add DepthwiseConv2D
ChiRuiChen Sep 19, 2022
40f313d
add DepthwiseConv2D
ChiRuiChen Sep 19, 2022
0d70ff6
use 2 rf in SeparableConv2D
ChiRuiChen Sep 19, 2022
0826bae
add 2 rf Attribute in SeparableConv2D
ChiRuiChen Sep 19, 2022
7852c3c
2 rf for SeparableConv2D
ChiRuiChen Sep 19, 2022
cbb56c2
fix shift_right_small data_T
ChiRuiChen Sep 23, 2022
d6887c4
total accum_t bits in Dense, Con2D
ChiRuiChen Sep 25, 2022
8def5dd
fix non trainable error in BatchNorm
ChiRuiChen Sep 25, 2022
e73643b
add profix ss for global_pooling2d
ChiRuiChen Sep 26, 2022
7f52e31
use accum_t in Merge Layer
ChiRuiChen Sep 26, 2022
5c0276c
add accum_t for Merge Config
ChiRuiChen Sep 26, 2022
b5b1cc4
use accum_t in add
ChiRuiChen Sep 26, 2022
a33c8cc
fix stride skip in pointwise
ChiRuiChen Sep 26, 2022
0ef42f4
use accum_t for GAP2D
ChiRuiChen Oct 11, 2022
0897013
add QDenseBatchnorm
ChiRuiChen Nov 2, 2022
cb98671
merge DenseBatchnorm into Dense
ChiRuiChen Nov 2, 2022
b4fd141
add QDenseBatchnorm into BatchNormalization
ChiRuiChen Nov 2, 2022
8926cff
add QDenseBatchnorm
ChiRuiChen Nov 2, 2022
e27c09f
add DenseBatchnorm
ChiRuiChen Nov 2, 2022
d360244
add (Dense, DenseBatchnorm)
ChiRuiChen Nov 2, 2022
188017c
fix n_in in densebatchnorm
ChiRuiChen Nov 2, 2022
569ebbd
use softmax legacy as default
ChiRuiChen Dec 7, 2022
4b8b42d
Merge branch 'fastmachinelearning:main' into Single-stream
ChiRuiChen Dec 11, 2022
2f6dfd5
add deecalo layers
ChiRuiChen Dec 13, 2022
04ee229
converter for deepcalo layers
ChiRuiChen Dec 13, 2022
9eda3cb
deepcalo layers templates
ChiRuiChen Dec 13, 2022
a3da57f
add ss profix for pointwise_conv_1d
ChiRuiChen Dec 13, 2022
374d3cb
ss profix for separable_conv_1d
ChiRuiChen Dec 13, 2022
253ca7d
add deepcalo layers
ChiRuiChen Dec 13, 2022
4768aa5
stream to depth
ChiRuiChen Dec 13, 2022
152b8a6
deepcalo header files
ChiRuiChen Dec 13, 2022
a09046a
fix rf is DIV_ROUNDUP(n_out, reuse_factor) is 1
ChiRuiChen Dec 13, 2022
a726846
add pointwise_conv_1d_cl_ss
ChiRuiChen Dec 13, 2022
1dd0f3a
add pointwise_mult_buffer_ss
ChiRuiChen Dec 13, 2022
6a5ddfc
use accum_t in multiply_ss
ChiRuiChen Dec 13, 2022
674dba8
typo
ChiRuiChen Dec 13, 2022
0ebdb80
add concatenate1d_ss
ChiRuiChen Dec 13, 2022
9d681d5
rf based on nout for Dense_ss and TimeDistributed
ChiRuiChen Dec 14, 2022
fb38175
rf based on nout for Dense_ss and TimeDistributed
ChiRuiChen Dec 14, 2022
5a295bf
_rf <= n_in
ChiRuiChen Dec 14, 2022
624800e
2 rf for TimeDistributed layer
ChiRuiChen Dec 14, 2022
3dea0ad
import TimeDistributed
ChiRuiChen Dec 14, 2022
31077b5
rf for dense_ss
ChiRuiChen Dec 14, 2022
9db03b1
rf == n_in
ChiRuiChen Dec 15, 2022
852e9de
don't round the accun_t
ChiRuiChen Dec 15, 2022
d3f7e09
fix use_bias = False in DenseBatchnorm
ChiRuiChen Dec 15, 2022
8ce813e
remove transpose in Conv2DBatchnorm
ChiRuiChen Dec 15, 2022
3c87a17
show accum bits
ChiRuiChen Dec 15, 2022
468fdb1
typo
ChiRuiChen Dec 15, 2022
b900c28
add output_t for resize layer
ChiRuiChen Jan 12, 2023
81ee8fa
add res_T for resize_ss
ChiRuiChen Jan 12, 2023
3d1eb2f
use 2d weight for dense and timedistributed
ChiRuiChen Jan 27, 2023
cbdf406
use 2d weight for timedistributed
ChiRuiChen Jan 27, 2023
573182c
use 2d weight for dense_ss
ChiRuiChen Jan 27, 2023
0dc9329
add load_2d_weight
ChiRuiChen Jan 27, 2023
91 changes: 85 additions & 6 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -103,18 +103,69 @@ def get_layer_mult_size(self, layer):
n_out_recr = n_out
return n_in, n_out, n_in_recr, n_out_recr

# 2022 12
if 'TimeDistributed' in layer.class_name:
n_in = layer.get_attr('n_in')
n_hid = layer.get_attr('n_hid')
n_out = layer.get_attr('n_out')
return n_in, n_hid, n_out

raise Exception(f'Cannot get mult size for layer {layer.name} ({layer.class_name})')

# 2022 12
# For Dense_ss, rf is chosen based on n_out only ---------------------------------------------
def get_valid_reuse_factors_nout(self, n_out, layer):
max_rf = n_out
valid_reuse_factors = []
for rf in range(1, max_rf + 1):
_assert = self._validate_reuse_factor_nout(n_out, rf, layer)
if _assert:
valid_reuse_factors.append(rf)
return valid_reuse_factors

def get_valid_reuse_factors(self, n_in, n_out):
def _validate_reuse_factor_nout(self, n_out, rf, layer):

# take the input_precision into account
input_precision = layer.get_input_variable().type.precision.width

_assert = (((n_out) % rf) == 0)

block_factor = int(math.ceil((n_out) / float(rf)))
_assert = _assert and (input_precision * block_factor) < 65536

return _assert

def set_closest_reuse_factor_nout(self, layer, n_out, attribute='reuse_factor'):
assert attribute is not None, 'Reuse factor attribute cannot be None'

valid_rf = self.get_valid_reuse_factors_nout(n_out, layer)
chosen_rf = layer.get_attr(attribute)

if chosen_rf not in valid_rf:
closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf)

# 2022 CHIRUI USE 2ND MAX RF
#closest_rf = valid_rf[-2]

print('WARNING: Invalid ReuseFactor={} in layer "{}". Using ReuseFactor={} instead. Valid ReuseFactor(s): {}.'
.format(chosen_rf, layer.name, closest_rf, ','.join(map(str, valid_rf))))
layer.set_attr(attribute, closest_rf)
#-----------------------------------------------------------------------------------------------

def get_valid_reuse_factors(self, n_in, n_out, layer):
max_rf = n_in * n_out
valid_reuse_factors = []
for rf in range(1, max_rf + 1):
_assert = self._validate_reuse_factor(n_in, n_out, rf)
_assert = self._validate_reuse_factor(n_in, n_out, rf, layer)
if _assert:
valid_reuse_factors.append(rf)
return valid_reuse_factors

def _validate_reuse_factor(self, n_in, n_out, rf):
def _validate_reuse_factor(self, n_in, n_out, rf, layer):

# take the input_precision into account
input_precision = layer.get_input_variable().type.precision.width

multfactor = min(n_in, rf)
multiplier_limit = int(math.ceil((n_in * n_out) / float(multfactor)))
#
@@ -126,7 +177,30 @@ def _validate_reuse_factor(self, n_in, n_out, rf):
# THIS ASSERTION IS FOR QoR AND EXECUTION TIME
#
_assert = _assert and (((n_in * n_out) % rf) == 0)


# 2022 CHIRUI
#
# THIS ASSERTION IS FOR MAKING SURE THAT (INPUT_PRECISION * BLOCK_FACTOR) WON'T EXCEED VIVADO 65535 BITWIDTH
# IT IS USED FOR THE RESHAPE PRAGMA OF WEIGHTS AND BIAS IN THE DENSE LAYER

## get the block_factor; it is valid for the first 2 kinds of the dense layer calculation
_rf = min(n_in * n_out, rf)
block_factor = int(math.ceil((n_in * n_out) / float(_rf)))
_assert = _assert and (input_precision * block_factor) < 65536

# ------------------Several Dense choices-----------------
#
# THIS ASSERTION IS FOR USING 1ST KIND OF THE DENSE LAYER
#
#_assert = _assert and _rf <= n_in
#
# THIS ASSERTION IS FOR USING THE MAX RF IN 1ST KIND OF THE DENSE LAYER
_assert = _assert and _rf == n_in
#
# THIS ASSERTION IS FOR USING 2ND KIND OF THE DENSE LAYER WITH MAX RF
#
#_assert = _assert and _rf == (n_in * n_out)

return _assert

def get_closest_reuse_factor(self, valid_rf, chosen_rf):
@@ -148,11 +222,16 @@ def get_closest_reuse_factor(self, valid_rf, chosen_rf):

def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor'):
assert attribute is not None, 'Reuse factor attribute cannot be None'

valid_rf = self.get_valid_reuse_factors(n_in, n_out)
valid_rf = self.get_valid_reuse_factors(n_in, n_out, layer)
chosen_rf = layer.get_attr(attribute)

if chosen_rf not in valid_rf:
closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf)

# 2022 CHIRUI USE 2ND MAX RF
#closest_rf = valid_rf[-2]

print('WARNING: Invalid ReuseFactor={} in layer "{}". Using ReuseFactor={} instead. Valid ReuseFactor(s): {}.'
.format(chosen_rf, layer.name, closest_rf, ','.join(map(str, valid_rf))))
layer.set_attr(attribute, closest_rf)
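The added constraint above caps the size of the reshaped weight/bias arrays: on top of the usual divisibility condition, the input precision times the block factor must stay below 65536 bits so the reshape pragma in the dense layer fits Vivado's limit, and this branch additionally pins the effective reuse factor to n_in (the first dense variant at its maximum reuse factor). A minimal standalone sketch of that check, with hypothetical layer dimensions (the real code reads these from the layer object):

```python
import math

# Hypothetical, self-contained version of the reuse-factor check added in this PR;
# the real implementation lives in _validate_reuse_factor and reads the layer object.
def reuse_factor_ok(n_in, n_out, rf, input_precision_bits):
    # total number of weights must divide evenly by the reuse factor
    ok = (n_in * n_out) % rf == 0
    # block_factor = weights handled per reuse iteration; the reshaped
    # weight/bias array width is input_precision * block_factor and must
    # stay below Vivado's 65536-bit limit
    _rf = min(n_in * n_out, rf)
    block_factor = int(math.ceil((n_in * n_out) / float(_rf)))
    ok = ok and (input_precision_bits * block_factor) < 65536
    # this PR pins the effective reuse factor to n_in (first dense variant, max rf)
    ok = ok and _rf == n_in
    return ok

# Example: a 128x64 dense layer with 16-bit inputs
print(reuse_factor_ok(n_in=128, n_out=64, rf=128, input_precision_bits=16))  # True
print(reuse_factor_ok(n_in=128, n_out=64, rf=64, input_precision_bits=16))   # False: rf != n_in
```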
4 changes: 2 additions & 2 deletions hls4ml/backends/fpga/fpga_types.py
@@ -143,7 +143,7 @@ def convert_precision(self, precision_converter):
class PackedTypeConverter(TypeDefinition, TypePrecisionConverter):
def definition_cpp(self):
n_elem_expr = '/' if self.unpack else '*'
return 'typedef nnet::array<{precision}, {n_elem}> {name};\n'.format(name=self.name, precision=self.precision.definition_cpp(), n_elem=str(self.n_elem) + n_elem_expr + str(self.n_pack))
return 'typedef {precision} {name};\n'.format(name=self.name, precision=self.precision.definition_cpp())

class HLSTypeConverter(object):
def __init__(self, precision_converter):
@@ -351,4 +351,4 @@ def convert(cls, weight_var):

#endregion

#endregion
#endregion
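The change above drops the nnet::array wrapper, so a packed stream type now resolves to the bare element precision, matching the single-stream (one element per stream word) layout this PR moves to. A sketch of the emitted typedef before and after the change, with a hypothetical precision and type name:

```python
# Hypothetical values; in hls4ml they come from the type object passed to the converter.
precision = 'ap_fixed<16,6>'
name = 'layer2_t'
n_elem, n_pack = 8, 1

# old behaviour: one stream word carries an nnet::array of n_elem * n_pack elements
before = 'typedef nnet::array<{precision}, {n_elem}> {name};\n'.format(
    name=name, precision=precision, n_elem=str(n_elem) + '*' + str(n_pack))
# new behaviour: one stream word carries a single element
after = 'typedef {precision} {name};\n'.format(name=name, precision=precision)

print(before)  # typedef nnet::array<ap_fixed<16,6>, 8*1> layer2_t;
print(after)   # typedef ap_fixed<16,6> layer2_t;
```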
2 changes: 1 addition & 1 deletion hls4ml/backends/fpga/passes/clone.py
@@ -26,7 +26,7 @@ def format(self, node):
params['output' + str(i + 1)] = node.variables[node.outputs[i]].name

if self.template is None:
self.template = 'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, ' + \
self.template = 'nnet::clone_stream_ss<{input_t}, {output_t}, {size}>({input}, ' + \
', '.join(['{output' + str(i + 1) + '}' for i in range(len(node.outputs))]) + \
');'

4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/passes/conv_same_pad.py
@@ -1,5 +1,5 @@
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D
from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, DepthwiseConv2D

class InsertZeroPaddingBeforeConv1D(OptimizerPass):
name = 'insert_zero_padding_before_conv1d'
@@ -50,7 +50,7 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass):
name = 'insert_zero_padding_before_conv2d'

def match(self, node):
is_match = isinstance(node, (Conv2D, SeparableConv2D)) and \
is_match = isinstance(node, (Conv2D, SeparableConv2D, DepthwiseConv2D)) and \
node.get_attr('padding') == 'same' and \
node.get_attr('filt_height') != 1 and node.get_attr('filt_width') != 1
return is_match
4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/passes/conv_stream.py
@@ -1,10 +1,10 @@
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D
from hls4ml.model.layers import Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, DepthwiseConv2D

class GenerateConvStreamingInstructions(OptimizerPass):
''' Generates the instructions for streaming implementation of CNNs '''
def match(self, node):
return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
return isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, DepthwiseConv2D))

def transform(self, model, node):
node_class = node.__class__.__name__
10 changes: 6 additions & 4 deletions hls4ml/backends/vivado/passes/convolution_templates.py
@@ -128,8 +128,8 @@ def format(self, node):
}};
const ap_uint<config{index}::filt_height * config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

conv2d_function_template = 'nnet::conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
depthconv2d_function_template = 'nnet::depthwise_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
conv2d_function_template = 'nnet::conv_2d_{data_format}_ss<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
depthconv2d_function_template = 'nnet::depthwise_conv_2d_{data_format}_ss<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

conv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_conv2d_stream.h']

@@ -184,8 +184,8 @@ def __init__(self):
typedef {pointwise_config} pointwise_config;
}};\n"""

sepconv1d_function_template = 'nnet::separable_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
sepconv2d_function_template = 'nnet::separable_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
sepconv1d_function_template = 'nnet::separable_conv_1d_{data_format}_ss<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
sepconv2d_function_template = 'nnet::separable_conv_2d_{data_format}_ss<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'

sepconv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_sepconv1d_stream.h']
sepconv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_sepconv2d_stream.h']
@@ -312,6 +312,7 @@ def format(self, node):
mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width')
mult_params['n_out'] = node.get_attr('n_chan')
mult_params['weight_t'] = node.get_weights('depthwise').type
mult_params['reuse'] = node.get_attr('reuse_factor_depthwise')
mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('depthwise').type.precision)
depthwise_mult_config = self.depthwise_mult_template.format(**mult_params)

@@ -344,6 +345,7 @@ def format(self, node):
mult_params['n_in'] = node.get_attr('n_chan')
mult_params['n_out'] = node.get_attr('n_filt')
mult_params['weight_t'] = node.get_weights('pointwise').type
mult_params['reuse'] = node.get_attr('reuse_factor_pointwise')
mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('pointwise').type.precision)
pointwise_mult_config = self.pointwise_mult_template.format(**mult_params)

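The two added reuse lines let the depthwise and pointwise stages of a SeparableConv2D carry independent reuse factors. A rough sketch of the multiplier parameters each stage ends up with; the attribute names come from the diff, but the node and its values are made up:

```python
# Made-up attribute values standing in for node.get_attr(...) calls.
node_attrs = {
    'n_chan': 16, 'n_filt': 32, 'filt_height': 3, 'filt_width': 3,
    'reuse_factor_depthwise': 9, 'reuse_factor_pointwise': 16,
}

depthwise_mult = {
    'n_in': node_attrs['n_chan'] * node_attrs['filt_height'] * node_attrs['filt_width'],
    'n_out': node_attrs['n_chan'],
    'reuse': node_attrs['reuse_factor_depthwise'],   # depthwise stage reuse factor
}
pointwise_mult = {
    'n_in': node_attrs['n_chan'],
    'n_out': node_attrs['n_filt'],
    'reuse': node_attrs['reuse_factor_pointwise'],   # pointwise stage reuse factor
}

print(depthwise_mult)  # {'n_in': 144, 'n_out': 16, 'reuse': 9}
print(pointwise_mult)  # {'n_in': 16, 'n_out': 32, 'reuse': 16}
```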
14 changes: 7 additions & 7 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,6 @@

from hls4ml.backends.backend import get_backend
from hls4ml.model.layers import Activation, BatchNormalization, Dense, Embedding, PReLU, ParametrizedActivation, Softmax
from hls4ml.model.layers import Activation, BatchNormalization, Dense, DenseBatchnorm, Embedding, PReLU, ParametrizedActivation, Softmax
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate

# Dense templates
Expand All @@ -22,13 +22,13 @@
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
dense_function_template = 'nnet::dense_ss<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h']

class DenseConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(Dense)
super().__init__((Dense, DenseBatchnorm))
self.template = dense_config_template

def format(self, node):
Expand All @@ -41,7 +41,7 @@ def format(self, node):

class DenseFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Dense, include_header=dense_include_list)
super().__init__((Dense, DenseBatchnorm), include_header=dense_include_list)
self.template = dense_function_template

def format(self, node):
Expand All @@ -67,7 +67,7 @@ def format(self, node):
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
batchnorm_function_template = 'nnet::normalize_ss<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'

batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']

Expand Down Expand Up @@ -117,8 +117,8 @@ def format(self, node):
typedef {inv_table_t.name} inv_table_t;
}};\n"""

activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
activ_function_template = 'nnet::{activation}_ss<{input_t}, {output_t}, {config}>({input}, {output});'
param_activ_function_template = 'nnet::{activation}_ss<{input_t}, {output_t}, {config}>({input}, {param}, {output});'

activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h']
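
With the _ss suffix, the generated call sites dispatch to the single-stream variants of the nnet functions. A small sketch of how the dense template string renders, with placeholder type and variable names:

```python
dense_function_template = 'nnet::dense_ss<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

# Placeholder names; in hls4ml these are filled in by DenseFunctionTemplate.format(node).
params = {
    'input_t': 'layer1_t', 'output_t': 'layer2_t', 'config': 'config2',
    'input': 'layer1_out', 'output': 'layer2_out', 'w': 'w2', 'b': 'b2',
}

print(dense_function_template.format(**params))
# nnet::dense_ss<layer1_t, layer2_t, config2>(layer1_out, layer2_out, w2, b2);
```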
