@@ -67,50 +67,53 @@ Tensor& avg_pool2d_out(
6767 out);
6868
6969 ScalarType in_type = in.scalar_type ();
70- ET_SWITCH_FLOAT_TYPES_AND (Long, in_type, ctx, " avg_pool2d.out" , CTYPE, [&]() {
71- if (divisor_override.has_value ()) {
72- int64_t divisor = divisor_override.value ();
73- // If divisor_override is specified, then we don't need to use `count` in
74- // the calculation. Simply sum x / divisor to get the output.
75- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
76- [](const CTYPE in_val,
77- int64_t in_idx,
78- CTYPE accum,
79- int64_t accum_idx) {
80- // Average pooling does not track indexes, so return 0 for accum_idx
81- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
82- },
83- [divisor](const int64_t count, const CTYPE accum) {
84- return accum / static_cast <CTYPE>(divisor);
85- },
86- count_include_pad,
87- in,
88- kernel_size,
89- stride,
90- padding,
91- {},
92- out);
93- } else {
94- apply_kernel_2d_reduce_then_map_fn<CTYPE>(
95- [](const CTYPE in_val,
96- int64_t in_idx,
97- CTYPE accum,
98- int64_t accum_idx) {
99- // Average pooling does not track indexes, so return 0 for accum_idx
100- return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
101- },
102- [](const int64_t count, const CTYPE accum) {
103- return accum / static_cast <CTYPE>(count);
104- },
105- count_include_pad,
106- in,
107- kernel_size,
108- stride,
109- padding,
110- {},
111- out);
112- }
113- });
70+ ET_SWITCH_FLOATHBF16_TYPES_AND (
71+ Long, in_type, ctx, " avg_pool2d.out" , CTYPE, [&]() {
72+ if (divisor_override.has_value ()) {
73+ int64_t divisor = divisor_override.value ();
74+ // If divisor_override is specified, then we don't need to use `count`
75+ // in the calculation. Simply sum x / divisor to get the output.
76+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
77+ [](const CTYPE in_val,
78+ int64_t in_idx,
79+ CTYPE accum,
80+ int64_t accum_idx) {
81+ // Average pooling does not track indexes, so return 0 for
82+ // accum_idx
83+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
84+ },
85+ [divisor](const int64_t count, const CTYPE accum) {
86+ return accum / static_cast <CTYPE>(divisor);
87+ },
88+ count_include_pad,
89+ in,
90+ kernel_size,
91+ stride,
92+ padding,
93+ {},
94+ out);
95+ } else {
96+ apply_kernel_2d_reduce_then_map_fn<CTYPE>(
97+ [](const CTYPE in_val,
98+ int64_t in_idx,
99+ CTYPE accum,
100+ int64_t accum_idx) {
101+ // Average pooling does not track indexes, so return 0 for
102+ // accum_idx
103+ return std::tuple<CTYPE, int64_t >(in_val + accum, 0 );
104+ },
105+ [](const int64_t count, const CTYPE accum) {
106+ return accum / static_cast <CTYPE>(count);
107+ },
108+ count_include_pad,
109+ in,
110+ kernel_size,
111+ stride,
112+ padding,
113+ {},
114+ out);
115+ }
116+ });
114117
115118 return out;
116119}
0 commit comments