@@ -407,10 +407,7 @@ def _(layer: Conv1D | Conv2D):
407407@_produce_kif .register (GlobalPooling1D )
408408@_produce_kif .register (GlobalPooling2D )
409409def _ (layer : Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D ):
410- if isinstance (layer , (Pooling1D , GlobalPooling1D )):
411- px_shape = (layer .attributes ['pool_width' ],)
412- else :
413- px_shape = (layer .attributes ['pool_height' ], layer .attributes ['pool_width' ])
410+ px_shape = _get_px_shape (layer )
414411 ch_out = ch_in = layer .attributes ['n_filt' ]
415412
416413 im2col_shape = * px_shape , ch_in , ch_out # conv kernel shape
@@ -432,6 +429,8 @@ def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
432429 raise ValueError ('Average pooling with non-power-of-2 pool size cannot be bit-exact' )
433430 f_out += int (f_add )
434431
432+ if isinstance (layer , (GlobalPooling1D , GlobalPooling2D )):
433+ k_out , i_out , f_out = k_out [0 ], i_out [0 ], f_out [0 ]
435434 return k_out , i_out , f_out
436435
437436
@@ -665,6 +664,22 @@ def _(node: UnaryLUT):
665664 default_register_precision (node )
666665
667666
667+ def _get_px_shape (node : Layer ):
668+ if isinstance (node , Pooling1D ):
669+ px_shape = (node .attributes ['pool_width' ],)
670+ elif isinstance (node , GlobalPooling1D ):
671+ inp_shape = get_input_shapes (node )[0 ]
672+ px_shape = (inp_shape [0 ],)
673+ elif isinstance (node , Pooling2D ):
674+ px_shape = (node .attributes ['pool_height' ], node .attributes ['pool_width' ])
675+ elif isinstance (node , GlobalPooling2D ):
676+ inp_shape = get_input_shapes (node )[0 ]
677+ px_shape = (inp_shape [0 ], inp_shape [1 ])
678+ else :
679+ raise ValueError (f'Layer { node .class_name } is not supported for pooling precision derivation' )
680+ return px_shape
681+
682+
668683@register_precision .register (Pooling1D )
669684@register_precision .register (Pooling2D )
670685@register_precision .register (GlobalPooling1D )
@@ -674,10 +689,7 @@ def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
674689 pool_op = node .attributes ['pool_op' ]
675690 if pool_op != 'Average' :
676691 return
677- if isinstance (node , (Pooling1D , GlobalPooling1D )):
678- px_shape = (node .attributes ['pool_width' ],)
679- else :
680- px_shape = (node .attributes ['pool_height' ], node .attributes ['pool_width' ])
692+ px_shape = _get_px_shape (node )
681693 i_add = int (log2 (prod (px_shape )))
682694 node .attributes ['accum_t' ].precision .width += i_add
683695 node .attributes ['accum_t' ].precision .integer += i_add
0 commit comments