File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -494,6 +494,25 @@ fn simple_eval_(
494494 } ;
495495 values. insert ( node. output [ 0 ] . clone ( ) , ys) ;
496496 }
497+ "GlobalAveragePool" => {
498+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
499+ let xs = get ( & node. input [ 0 ] ) ?;
500+ let [ n_dim, c_dim, kernel_shape @ ..] = xs. dims ( ) else {
501+ bail ! (
502+ "only 2d GlobalAveragePool is supported, kernel shape {:?}" ,
503+ xs. dims( )
504+ ) ;
505+ } ;
506+ let ys = match kernel_shape {
507+ [ d1, d2] => xs. avg_pool2d ( ( * d1, * d2) ) ,
508+ [ d1] => {
509+ let xs = xs. unsqueeze ( 1 ) ?;
510+ xs. avg_pool2d ( ( 1 , * d1) )
511+ }
512+ _ => todo ! ( ) ,
513+ } ?;
514+ values. insert ( node. output [ 0 ] . clone ( ) , ys) ;
515+ }
497516 "AveragePool" => {
498517 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
499518 let dilations = get_attr_opt :: < [ i64 ] > ( node, "dilations" ) ?;
You can’t perform that action at this time.
0 commit comments