@@ -21,6 +21,7 @@ namespace {
2121
2222template <typename CTYPE_IN, typename CTYPE_OUT>
2323void compute_variance (
24+ KernelRuntimeContext& ctx,
2425 const Tensor& in,
2526 Tensor& out,
2627 optional<ArrayRef<int64_t >> dim_list,
@@ -33,22 +34,26 @@ void compute_variance(
3334 }
3435 } else {
3536 MapReduceOverDimListPlan plan (in, dim_list);
36- for (const auto out_ix : c10::irange (out.numel ())) {
37- CTYPE_OUT sum = plan.execute <CTYPE_IN, CTYPE_OUT>(
38- [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
39- [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
40- out_ix);
41- CTYPE_OUT mean = sum / static_cast <CTYPE_OUT>(num);
42- CTYPE_OUT sum2 = plan.execute <CTYPE_IN, CTYPE_OUT>(
43- [mean](CTYPE_IN v) {
44- return (
45- (static_cast <CTYPE_OUT>(v) - mean) *
46- (static_cast <CTYPE_OUT>(v) - mean));
47- },
48- [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
49- out_ix);
50- out_data[out_ix] = sum2 / denominator;
51- }
37+ const bool success = parallel_for_each_reduce_over_dim_list_output_index (
38+ in, dim_list, out, [&](const auto begin, const auto end) {
39+ for (const auto out_ix : c10::irange (begin, end)) {
40+ CTYPE_OUT sum = plan.execute <CTYPE_IN, CTYPE_OUT>(
41+ [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
42+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
43+ out_ix);
44+ CTYPE_OUT mean = sum / static_cast <CTYPE_OUT>(num);
45+ CTYPE_OUT sum2 = plan.execute <CTYPE_IN, CTYPE_OUT>(
46+ [mean](CTYPE_IN v) {
47+ return (
48+ (static_cast <CTYPE_OUT>(v) - mean) *
49+ (static_cast <CTYPE_OUT>(v) - mean));
50+ },
51+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
52+ out_ix);
53+ out_data[out_ix] = sum2 / denominator;
54+ }
55+ });
56+ ET_KERNEL_CHECK_MSG (ctx, success, Internal, , " parallel_for failed" );
5257 }
5358}
5459
@@ -90,7 +95,7 @@ Tensor& var_out(
9095
9196 ET_SWITCH_FLOATHBF16_TYPES (in.scalar_type (), ctx, name, CTYPE_IN, [&] {
9297 ET_SWITCH_FLOATHBF16_TYPES (out.scalar_type (), ctx, name, CTYPE_OUT, [&] {
93- compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
98+ compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
9499 });
95100 });
96101
@@ -135,7 +140,7 @@ Tensor& var_correction_out(
135140
136141 ET_SWITCH_FLOATHBF16_TYPES (in.scalar_type (), ctx, name, CTYPE_IN, [&] {
137142 ET_SWITCH_FLOATHBF16_TYPES (out.scalar_type (), ctx, name, CTYPE_OUT, [&] {
138- compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
143+ compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
139144 });
140145 });
141146
0 commit comments