@@ -28,6 +28,29 @@ namespace impl {
2828namespace G3 {
2929namespace native {
3030
31+ template <typename CTYPE_IN, typename CTYPE_OUT>
32+ void mean_out_ (
33+ const Tensor& in,
34+ optional<ArrayRef<int64_t >> dim_list,
35+ __ET_UNUSED bool keepdim,
36+ __ET_UNUSED optional<ScalarType> dtype,
37+ Tensor& out) {
38+ CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
39+ const size_t num = torch::executor::get_reduced_dim_product (in, dim_list);
40+ for (size_t out_ix = 0 ; out_ix < out.numel (); ++out_ix) {
41+ CTYPE_OUT sum = 0 ;
42+ if (in.numel () > 0 ) {
43+ sum = torch::executor::map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
44+ [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
45+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
46+ in,
47+ dim_list,
48+ out_ix);
49+ }
50+ out_data[out_ix] = sum / static_cast <float >(num);
51+ }
52+ }
53+
3154int prepare_data (
3255 const Tensor& in,
3356 Tensor& out,
@@ -60,7 +83,7 @@ int prepare_data(
6083 return num_axis_dims;
6184}
6285
63- Tensor& mean_dim_out (
86+ Tensor& mean_out (
6487 KernelRuntimeContext& ctx,
6588 const Tensor& in,
6689 optional<ArrayRef<int64_t >> dim_list,
@@ -169,29 +192,8 @@ Tensor& mean_dim_out(
169192 InvalidArgument,
170193 out);
171194
172- ET_SWITCH_REALHB_TYPES (in.scalar_type (), ctx, " mean.out" , CTYPE_IN, [&] {
173- ET_SWITCH_FLOATH_TYPES (
174- out.scalar_type (), ctx, " mean.out" , CTYPE_OUT, [&] {
175- CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
176- const size_t num =
177- torch::executor::get_reduced_dim_product (in, dim_list);
178- for (size_t out_ix = 0 ; out_ix < out.numel (); ++out_ix) {
179- CTYPE_OUT sum = 0 ;
180- if (in.numel () > 0 ) {
181- sum = torch::executor::
182- map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
183- [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
184- [](CTYPE_OUT outv, CTYPE_OUT acc) {
185- return acc + outv;
186- },
187- in,
188- dim_list,
189- out_ix);
190- }
191- out_data[out_ix] = sum / static_cast <float >(num);
192- }
193- });
194- });
195+ mean_out_<float , float >(in, dim_list, keepdim, dtype, out);
196+ return out;
195197 }
196198
197199 return out;
0 commit comments