Skip to content

Commit b11adf7

Browse files
dvdskosyvokon
andcommitted
Adds GlobalAveragePool onnx operator
Only 2d GlobalAveragePool is supported as we use avg_pool2d internally. Co-authored-by: Oleksiy <[email protected]>
1 parent b42c580 commit b11adf7

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

candle-onnx/src/eval.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,25 @@ fn simple_eval_(
494494
};
495495
values.insert(node.output[0].clone(), ys);
496496
}
497+
"GlobalAveragePool" => {
498+
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
499+
let xs = get(&node.input[0])?;
500+
let [n_dim, c_dim, kernel_shape @ ..] = xs.dims() else {
501+
bail!(
502+
"only 2d GlobalAveragePool is supported, kernel shape {:?}",
503+
xs.dims()
504+
);
505+
};
506+
let ys = match kernel_shape {
507+
[d1, d2] => xs.avg_pool2d((*d1, *d2)),
508+
[d1] => {
509+
let xs = xs.unsqueeze(1)?;
510+
xs.avg_pool2d((1, *d1))
511+
}
512+
_ => todo!(),
513+
}?;
514+
values.insert(node.output[0].clone(), ys);
515+
}
497516
"AveragePool" => {
498517
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
499518
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;

0 commit comments

Comments
 (0)