from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import EinsumDense

from .reshaping_templates import transpose_config_gen

# Shared Dense template

conv_dense_config_template = """struct config{index}_dense : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned reuse_factor = {reuse};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned n_zeros = {nzeros};
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    template<class data_T, class res_T, class CONFIG_T>
    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""
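
# For illustration only: with hypothetical values such as index=5, n_in=16, n_out=8
# and reuse=1, the template above would render roughly as
#
#   struct config5_dense : nnet::dense_config {
#       static const unsigned n_in = 16;
#       static const unsigned n_out = 8;
#       static const unsigned reuse_factor = 1;
#       ...
#   };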

# EinsumDense template

einsum_dense_config_template = '''
struct config{index} {{
    typedef config{index}_tpose_inp tpose_inp_conf;
    typedef config{index}_tpose_out tpose_out_conf;
    typedef config{index}_dense dense_conf;

    // Layer Sizes
    static const unsigned n_free_data = {n_free_data};
    static const unsigned n_free_kernel = {n_free_kernel};
    static const unsigned n_contract = {n_contract};
    static const unsigned n_inplace = {n_inplace};

    // Resource reuse info
    static const unsigned io_type = nnet::{iotype};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse_factor};
    static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1
    static const bool store_weights_in_bram = false; // NOT USED
}};
'''
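
# Note: the typedef names above (config{index}_tpose_inp, config{index}_tpose_out and
# config{index}_dense) must match the sub-config names generated by
# EinsumDenseConfigTemplate.format() below.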

einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h']
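
# For illustration only: with hypothetical variable and config names, the call template
# above would expand to something like
#
#   nnet::einsum_dense<layer4_t, result_t, config5>(layer4_out, layer5_out, w5, b5);
#
# EinsumDenseFunctionTemplate.format() below fills in the actual names for each layer.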


class EinsumDenseConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(EinsumDense)
        self.template = einsum_dense_config_template
        self.dense_template = conv_dense_config_template

    def format(self, node: EinsumDense):
        """Emit the transpose, dense and einsum configs that together implement this EinsumDense layer."""
        default_params = self._default_config_params(node)

        strategy = node.model.config.get_strategy(node)
        io_type = node.model.config.get_config_value('IOType')

        assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
        assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'

        # EinsumDense config
        params = default_params.copy()
        params['strategy'] = strategy
        params['n_free_data'] = node.attributes.attributes['n_free_data']
        params['n_free_kernel'] = node.attributes.attributes['n_free_kernel']
        params['n_contract'] = node.attributes.attributes['n_contract']
        params['n_inplace'] = node.attributes.attributes['n_inplace']
        params['parallelization_factor'] = node.attributes.attributes['parallelization_factor']

        einsum_conf = self.template.format(**params)

        # Input/output transpose configs
        inp_shape = node.attributes.attributes['inp_shape']
        out_interpert_shape = node.attributes.attributes['out_interpert_shape']
        inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs']
        out_tpose_idxs = node.attributes.attributes['out_tpose_idxs']
        tpose_inp_conf_name = f'config{node.index}_tpose_inp'
        tpose_out_conf_name = f'config{node.index}_tpose_out'

        inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
        out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)

        # Dense config
        dense_params = default_params.copy()
        dense_params['strategy'] = strategy
        dense_params['n_in'] = node.attributes.attributes['n_contract']
        dense_params['n_out'] = node.attributes.attributes['n_free_kernel']
        if node.attributes.attributes['n_inplace'] == 1:
            dense_params['nzeros'] = node.get_weights('weight').nzeros  # type: ignore
        else:
            dense_params['nzeros'] = '-1; // Not meaningful when kernels are switching'
        dense_params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision  # type: ignore
        )

        dense_params['dense_function'] = 'DenseLatency'  # Latency only for now

        dense_config = self.dense_template.format(**dense_params)

        return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf))


class EinsumDenseFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(EinsumDense, include_header=einsum_dense_include_list)
        self.template = einsum_dense_function_template

    def format(self, node):
        """Fill the einsum_dense function-call template with this layer's variables and weights."""
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name

        return self.template.format(**params)