1+ from math import prod
2+
3+ import numpy as np
4+
15from hls4ml .backends .template import FunctionCallTemplate , LayerConfigTemplate
26from hls4ml .model .layers import Resize , Transpose , ZeroPadding1D , ZeroPadding2D
37
@@ -97,16 +101,64 @@ def format(self, node):
97101
98102# Transpose templates
99103
100- transpose_config_template = """struct config{index} : nnet::transpose_config {{
101- static const unsigned depth = {depth};
102- static const unsigned height = {height};
103- static const unsigned width = {width};
104- static constexpr unsigned perm[3] = {{{perm_str}}};
105- }};\n """
106104
107- transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});'
108-
109- transpose_include_list = ['nnet_utils/nnet_array.h' , 'nnet_utils/nnet_stream.h' ]
105+ transpose_include_list = ['nnet_utils/nnet_transpose.h' , 'nnet_utils/nnet_transpose_stream.h' ]
106+
107+ transpose_config_template = """struct {config_name} {{
108+ static const unsigned dims = {dims};
109+ static const unsigned N = {N};
110+ static const unsigned* const from_shape;
111+ static const unsigned* const to_shape;
112+ static const unsigned* const perm;
113+ static const unsigned* const perm_strides;
114+ }};
115+
116+ unsigned {config_name}_from_shape[{dims}] = {{{from_shape}}};
117+ unsigned {config_name}_to_shape[{dims}] = {{{to_shape}}};
118+ unsigned {config_name}_perm[{dims}] = {{{perm}}};
119+ unsigned {config_name}_perm_strides[{dims}] = {{{perm_strides}}};
120+
121+ const unsigned* const {config_name}::from_shape = {config_name}_from_shape;
122+ const unsigned* const {config_name}::to_shape = {config_name}_to_shape;
123+ const unsigned* const {config_name}::perm = {config_name}_perm;
124+ const unsigned* const {config_name}::perm_strides = {config_name}_perm_strides;
125+ """
126+
127+ transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});'
128+
129+
130+ def permute_config_gen (name : str , shape : tuple [int , ...], perm : tuple [int , ...]):
131+ """
132+ Generate a configuration string for a permute operation. Operates by mapping the output index to input input index by:
133+ - unravel the output index
134+ - map each dimension to the corresponding stride in the input tensor, sum
135+ The operation can be expressed as:
136+
137+ new_shape = tuple(shape[i] for i in perm)
138+ strides = np.cumprod((shapes[1:] + (1,))[::-1])[::-1]
139+ perm_strides = [strides[i] for i in perm]
140+ out[index] = inp[np.dot(np.unravel_index(index, new_shape), perm_strides)]
141+
142+ Args:
143+ name (str): The name of the configuration.
144+ shape (tuple[int, ...]): The shape of the input tensor.
145+ perm (tuple[int, ...]): The permutation of the dimensions.
146+
147+ Returns:
148+ str: The formatted configuration string for the permute operation.
149+ """
150+ new_shape = tuple (shape [i ] for i in perm )
151+ strides = np .cumprod ((shape [1 :] + (1 ,))[::- 1 ])[::- 1 ]
152+ perm_strides = tuple (int (strides [i ]) for i in perm )
153+ return transpose_config_template .format (
154+ dims = len (shape ),
155+ N = prod (shape ),
156+ from_shape = ', ' .join (str (x ) for x in shape ),
157+ perm = ', ' .join (str (x ) for x in perm ),
158+ perm_strides = ', ' .join (str (x ) for x in perm_strides ),
159+ to_shape = ', ' .join (str (x ) for x in new_shape ),
160+ config_name = name ,
161+ )
110162
111163
112164class TransposeConfigTemplate (LayerConfigTemplate ):
@@ -115,18 +167,18 @@ def __init__(self):
115167 self .template = transpose_config_template
116168
117169 def format (self , node ):
118- params = self ._default_config_params (node )
119-
120- return self .template .format (** params )
170+ shape = tuple (node .get_input_variable ().shape )
171+ perm = tuple (node .get_attr ('perm' ))
172+ name = f'config{ node .index } '
173+ return permute_config_gen (name , shape , perm )
121174
122175
123176class TransposeFunctionTemplate (FunctionCallTemplate ):
124177 def __init__ (self ):
125- super ().__init__ (Transpose , include_header = transpose_include_list )
126178 self .template = transpose_function_template
179+ super ().__init__ (Transpose , include_header = transpose_include_list )
127180
128181 def format (self , node ):
129182 params = self ._default_function_params (node )
130- params ['dim' ] = node .get_attr ('dim' )
131-
183+ params ['config_name' ] = f'config{ node .index } '
132184 return self .template .format (** params )
0 commit comments