@@ -44,23 +44,24 @@ Tensor& mean_dim_out(
4444 InvalidArgument,
4545 out);
4646
47- ET_SWITCH_REALHB_TYPES (in.scalar_type (), ctx, " mean.out" , CTYPE_IN, [&] {
48- ET_SWITCH_FLOATH_TYPES (out.scalar_type (), ctx, " mean.out" , CTYPE_OUT, [&] {
49- CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
50- const size_t num = get_reduced_dim_product (in, dim_list);
51- for (size_t out_ix = 0 ; out_ix < out.numel (); ++out_ix) {
52- CTYPE_OUT sum = 0 ;
53- if (in.numel () > 0 ) {
54- sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
55- [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
56- [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
57- in,
58- dim_list,
59- out_ix);
60- }
61- out_data[out_ix] = sum / static_cast <float >(num);
62- }
63- });
47+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " mean.out" , CTYPE_IN, [&] {
48+ ET_SWITCH_FLOATHBF16_TYPES (
49+ out.scalar_type (), ctx, " mean.out" , CTYPE_OUT, [&] {
50+ CTYPE_OUT* out_data = out.mutable_data_ptr <CTYPE_OUT>();
51+ const size_t num = get_reduced_dim_product (in, dim_list);
52+ for (size_t out_ix = 0 ; out_ix < out.numel (); ++out_ix) {
53+ CTYPE_OUT sum = 0 ;
54+ if (in.numel () > 0 ) {
55+ sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
56+ [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
57+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
58+ in,
59+ dim_list,
60+ out_ix);
61+ }
62+ out_data[out_ix] = sum / static_cast <float >(num);
63+ }
64+ });
6465 });
6566
6667 return out;
0 commit comments