11import typing
22from copy import copy
33from functools import reduce , singledispatch
4- from math import ceil , log2
4+ from math import ceil , log2 , prod
55from typing import Sequence
66from warnings import warn
77
1717 Einsum ,
1818 EinsumDense ,
1919 GlobalPooling1D ,
20+ GlobalPooling2D ,
2021 Input ,
2122 Layer ,
2223 Merge ,
2324 Pooling1D ,
25+ Pooling2D ,
2426 Reshape ,
2527 Softmax ,
2628)
@@ -101,52 +103,6 @@ def _(layer: FixedPointQuantizer):
101103 return ((k , i , f ),)
102104
103105
104- @request_kif .register (Pooling1D )
105- # @request_kif.register(Pooling2D)
106- @request_kif .register (GlobalPooling1D )
107- # @request_kif.register(GlobalPooling2D)
108- def _ (layer : Pooling1D | GlobalPooling1D ):
109- # inp_shape = get_input_shapes(layer)[0]
110- out_shape = get_output_shape (layer )
111- pool_width = layer .attributes .attributes ['pool_width' ]
112- stride_width = layer .attributes .attributes ['stride_width' ]
113- pool_op = layer .attributes .attributes ['pool_op' ]
114- if isinstance (layer , Pooling1D ):
115- pad_0_0 : int = layer .attributes .attributes ['pad_left' ]
116- else :
117- pad_0_0 = 0
118- is_ch_last = layer .attributes .attributes ['data_format' ] == 'channels_last'
119-
120- k = np .ones (out_shape , dtype = np .int8 )
121- i = np .full (out_shape , - 127 , dtype = np .int8 )
122- f = np .full (out_shape , 126 , dtype = np .int8 )
123-
124- _ , i_out , f_out = requested_kif (layer )
125-
126- if not is_ch_last :
127- i = np .moveaxis (i , 0 , - 1 )
128- f = np .moveaxis (f , 0 , - 1 )
129-
130- for idx_out in range (k .shape [- 1 ]):
131- i_in_0 = i_out * stride_width - pad_0_0
132- i_in_1 = i_in_0 + pool_width
133- if i_in_0 < 0 :
134- i_in_0 = 0
135- i [..., i_in_0 :i_in_1 ] = i_out [..., idx_out ]
136- f [..., i_in_0 :i_in_1 ] = f_out [..., idx_out ]
137-
138- if not is_ch_last :
139- i = np .moveaxis (i , - 1 , 0 )
140- f = np .moveaxis (f , - 1 , 0 )
141-
142- if pool_op == 'Average' :
143- ln2_size = np .log2 (pool_width )
144- i += np .ceil (ln2_size ).astype (np .int8 )
145- if not ln2_size .is_integer ():
146- f [:] = 126
147- return ((k , i , f ),)
148-
149-
150106@request_kif .register
151107def _ (layer : Reshape ):
152108 inp_shape = get_input_shapes (layer )[0 ]
@@ -332,15 +288,15 @@ def im2col(kernel_size: Sequence[int], *arrs: np.ndarray):
332288
333289def pad_arrs (node : Layer , pad_val : float = 0 , * arrs : np .ndarray ):
334290 out_arrs = []
335- if node .class_name .endswith ('Conv2D ' ):
291+ if node .class_name .endswith ('2D ' ):
336292 pad_top = node .attributes .attributes ['pad_top' ]
337293 pad_bottom = node .attributes .attributes ['pad_bottom' ]
338294 pad_left = node .attributes .attributes ['pad_left' ]
339295 pad_right = node .attributes .attributes ['pad_right' ]
340296 for arr in arrs :
341297 r = np .pad (arr , ((pad_top , pad_bottom ), (pad_left , pad_right ), (0 , 0 )), constant_values = pad_val )
342298 out_arrs .append (r )
343- elif node .class_name .endswith ('Conv1D ' ):
299+ elif node .class_name .endswith ('1D ' ):
344300 pad_left = node .attributes .attributes ['pad_left' ]
345301 pad_right = node .attributes .attributes ['pad_right' ]
346302 for arr in arrs :
@@ -352,11 +308,11 @@ def pad_arrs(node: Layer, pad_val: float = 0, *arrs: np.ndarray):
352308
353309
354310def stride_arrs (node : Layer , * arrs : np .ndarray ):
355- if node .class_name .endswith ('Conv2D ' ):
311+ if node .class_name .endswith ('2D ' ):
356312 st_h = node .attributes .attributes ['stride_height' ]
357313 st_w = node .attributes .attributes ['stride_width' ]
358314 return tuple (arr [::st_h , ::st_w ] for arr in arrs )
359- if node .class_name .endswith ('Conv1D ' ):
315+ if node .class_name .endswith ('1D ' ):
360316 st_w = node .attributes .attributes ['stride_width' ]
361317 return tuple (arr [::st_w ] for arr in arrs )
362318 raise ValueError (f'Layer { node .class_name } is not supported for stride_arrs' )
@@ -365,6 +321,7 @@ def stride_arrs(node: Layer, *arrs: np.ndarray):
365321@produce_kif .register (Conv1D )
366322@produce_kif .register (Conv2D )
367323def _ (layer : Conv1D | Conv2D ):
324+ assert layer .attributes .attributes ['data_format' ] == 'channels_last' , 'Only channels_last format is supported'
368325 kernel = layer .attributes .attributes ['weight' ].data
369326 _bias = layer .attributes .attributes ['bias' ]
370327 bias = _bias .data if _bias is not None else 0
@@ -380,6 +337,39 @@ def _(layer: Conv1D | Conv2D):
380337 return k .astype (np .int8 ), i , f
381338
382339
340+ @produce_kif .register (Pooling1D )
341+ @produce_kif .register (Pooling2D )
342+ @produce_kif .register (GlobalPooling1D )
343+ @produce_kif .register (GlobalPooling2D )
344+ def _ (layer : Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D ):
345+ if isinstance (layer , (Pooling1D , GlobalPooling1D )):
346+ px_shape = (layer .attributes ['pool_width' ],)
347+ else :
348+ px_shape = (layer .attributes ['pool_height' ], layer .attributes ['pool_width' ])
349+ ch_out = ch_in = layer .attributes ['n_filt' ]
350+
351+ im2col_shape = * px_shape , ch_in , ch_out # conv kernel shape
352+ k_in , i_in , f_in = get_input_kifs (layer )[0 ]
353+ if isinstance (layer , (Pooling1D , Pooling2D )):
354+ k_in , i_in , f_in = pad_arrs (layer , 0 , k_in , i_in , f_in )
355+ k_in , i_in , f_in = im2col (im2col_shape , k_in , i_in , f_in )
356+ if isinstance (layer , (Pooling1D , Pooling2D )):
357+ k_in , i_in , f_in = stride_arrs (layer , k_in , i_in , f_in )
358+
359+ k_out = k_in .reshape (* k_in .shape [:- 1 ], - 1 , ch_in ).max (axis = - 2 ).astype (np .int8 )
360+ i_out = i_in .reshape (* i_in .shape [:- 1 ], - 1 , ch_in ).max (axis = - 2 ).astype (np .int8 )
361+ f_out = f_in .reshape (* f_in .shape [:- 1 ], - 1 , ch_in ).max (axis = - 2 ).astype (np .int8 )
362+
363+ pool_op = layer .attributes ['pool_op' ]
364+ if pool_op == 'Average' :
365+ f_add = log2 (prod (px_shape ))
366+ if not f_add .is_integer ():
367+ raise ValueError ('Average pooling with non-power-of-2 pool size cannot be bit-exact' )
368+ f_out += int (f_add )
369+
370+ return k_out , i_out , f_out
371+
372+
383373@produce_kif .register
384374def _ (layer : BatchNormalization ):
385375 k_in , i_in , f_in = get_input_kifs (layer )[0 ]
0 commit comments