@@ -60,7 +60,7 @@ int prepare_data(
   return num_axis_dims;
 }

-Tensor& mean_dim_out(
+Tensor& mean_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     optional<ArrayRef<int64_t>> dim_list,
@@ -169,29 +169,32 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);

-    ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-      ET_SWITCH_FLOATH_TYPES(
-          out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-            CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-            const size_t num =
-                torch::executor::get_reduced_dim_product(in, dim_list);
-            for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-              CTYPE_OUT sum = 0;
-              if (in.numel() > 0) {
-                sum = torch::executor::
-                    map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
-                        [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                        [](CTYPE_OUT outv, CTYPE_OUT acc) {
-                          return acc + outv;
-                        },
-                        in,
-                        dim_list,
-                        out_ix);
-              }
-              out_data[out_ix] = sum / static_cast<float>(num);
-            }
-          });
-    });
+    ET_SWITCH_REALHBBF16_TYPES(
+        in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
+          ET_SWITCH_FLOATHBF16_TYPES(
+              out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
+                CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+                const size_t num =
+                    torch::executor::get_reduced_dim_product(in, dim_list);
+                for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+                  CTYPE_OUT sum = 0;
+                  if (in.numel() > 0) {
+                    sum = torch::executor::
+                        map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                            [](CTYPE_IN v) {
+                              return static_cast<CTYPE_OUT>(v);
+                            },
+                            [](CTYPE_OUT outv, CTYPE_OUT acc) {
+                              return acc + outv;
+                            },
+                            in,
+                            dim_list,
+                            out_ix);
+                  }
+                  out_data[out_ix] = sum / static_cast<float>(num);
+                }
+              });
+        });
   }

   return out;
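Note on the change above: the kernel entry point is renamed to mean_out, and the dtype dispatch is widened so that ET_SWITCH_REALHBBF16_TYPES / ET_SWITCH_FLOATHBF16_TYPES additionally accept BFloat16 inputs and outputs; the reduction body itself is unchanged. Logically, each input element is cast to the output type, summed over the dims in dim_list, and the sum is divided by the number of reduced elements. Below is a minimal standalone sketch of that pattern for a row-major tensor with keepdim semantics. The helper name mean_over_dims and its signature are hypothetical illustrations only, not the ExecuTorch map_reduce_over_dim_list / get_reduced_dim_product APIs, and dims are assumed to be non-negative.

// Standalone sketch of the mean-over-dim-list pattern used in the kernel.
#include <cstdint>
#include <iostream>
#include <vector>

// Mean over `dims` of a row-major tensor with keepdim=true:
// reduced coordinates are collapsed to size 1 in the output.
std::vector<float> mean_over_dims(
    const std::vector<float>& in,
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& dims) {
  const size_t rank = shape.size();
  std::vector<bool> reduced(rank, false);
  int64_t num = 1; // product of reduced dim sizes: the divisor
  for (int64_t d : dims) {
    reduced[d] = true;
    num *= shape[d];
  }

  // Output shape: reduced dims become 1.
  std::vector<int64_t> out_shape(shape);
  int64_t out_numel = 1;
  for (size_t d = 0; d < rank; ++d) {
    if (reduced[d]) {
      out_shape[d] = 1;
    }
    out_numel *= out_shape[d];
  }

  // "map" (cast to output type) + "reduce" (sum): accumulate every input
  // element into the output slot obtained by zeroing its reduced coords.
  std::vector<float> out(out_numel, 0.0f);
  std::vector<int64_t> coord(rank, 0);
  for (size_t i = 0; i < in.size(); ++i) {
    int64_t out_ix = 0;
    for (size_t d = 0; d < rank; ++d) {
      out_ix = out_ix * out_shape[d] + (reduced[d] ? 0 : coord[d]);
    }
    out[out_ix] += in[i];
    // Advance the row-major coordinate (odometer-style).
    for (size_t d = rank; d-- > 0;) {
      if (++coord[d] < shape[d]) {
        break;
      }
      coord[d] = 0;
    }
  }

  // Final divide by the number of reduced elements, as in the kernel.
  for (float& v : out) {
    v /= static_cast<float>(num);
  }
  return out;
}

int main() {
  // Mean over dim 1 of a 2x3 tensor -> {2, 5}.
  auto out = mean_over_dims({1, 2, 3, 4, 5, 6}, {2, 3}, {1});
  for (float v : out) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
}

As in the kernel, the sum is accumulated in the wider output type and divided once at the end, rather than dividing per element, which keeps the inner loop to a single add per input value.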