
Commit efc4b46

CaoE authored and pytorchmergebot committed
Add cascade sum support for Inductor CPP backend (pytorch#156296)
Fixes pytorch#154703

Add cascade summation support for the Inductor CPP backend to improve precision for large-size summations.

Currently, the Inductor CPP backend does the sum reduction directly. As shown in pytorch#154703, when the summation size is large and the degree of parallelism is small, direct reduction causes an intolerable precision loss:

```cpp
extern "C" void kernel(float* in_out_ptr0, const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = tmp_acc0_vec + tmp0;
                    }
                }
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```

After adding cascade sum support:

```cpp
extern "C" void kernel(float* in_out_ptr0, const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    {
        {
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            at::vec::Vectorized<float> masked_tmp_acc0_vec = at::vec::Vectorized<float>(0);
            CascadeSumHelper<float, 65536> scalar_cascade_helper0(static_cast<int64_t>(3000000000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> cascade_helper0(static_cast<int64_t>(187500000L));
            CascadeSumHelper<at::vec::Vectorized<float>, 65536> masked_cascade_helper0(static_cast<int64_t>(0L));
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(3000000000L); x0+=static_cast<int64_t>(16L))
            {
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(3000000000L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        tmp_acc0_vec = cascade_sum_combine(tmp0, &cascade_helper0);
                    }
                }
            }
            tmp_acc0 = cascade_sum_final(&scalar_cascade_helper0);
            tmp_acc0_vec = cascade_sum_final(&cascade_helper0);
            masked_tmp_acc0_vec = cascade_sum_final(&masked_cascade_helper0);
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float, 1>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec + masked_tmp_acc0_vec);
            out_ptr0[static_cast<int64_t>(0L)] = static_cast<float>(tmp_acc0);
        }
    }
    {
        {
            {
                auto tmp0 = out_ptr0[static_cast<int64_t>(0L)];
                auto tmp1 = static_cast<float>(3000000000.0);
                auto tmp2 = tmp0 / tmp1;
                in_out_ptr0[static_cast<int64_t>(0L)] = tmp2;
            }
        }
    }
}
```

This inevitably costs some performance when cascade sum is enabled; for the case shown in pytorch#154703, performance drops by ~3%.

Pull Request resolved: pytorch#156296
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
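For context, here is a minimal sketch (not the Inductor implementation) of why cascade summation helps: a single running float32 accumulator stops absorbing small addends once it grows large, while chunked partial sums keep addends and accumulators at comparable magnitudes. The `cascade_sum` helper below is illustrative only; the chunk size mirrors the `65536` template argument of the generated `CascadeSumHelper`.

```python
import torch

# In float32, once an accumulator reaches 2**24, adding 1.0 no longer
# changes it -- the failure mode of a single running accumulator.
acc = torch.tensor(float(2**24), dtype=torch.float32)
assert acc + 1.0 == acc  # 16777216.0 + 1.0 rounds back to 16777216.0

def cascade_sum(x: torch.Tensor, chunk: int = 65536) -> torch.Tensor:
    """Illustrative cascade sum: reduce fixed-size chunks first, then
    reduce the partial sums recursively, so every addition combines
    values of similar magnitude and fewer low-order bits are lost."""
    x = x.reshape(-1)
    if x.numel() <= chunk:
        return x.sum()
    partials = torch.stack([c.sum() for c in x.split(chunk)])
    return cascade_sum(partials, chunk)
```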
Parent commit: 1ca8388

3 files changed: +382, -91 lines

test/inductor/test_cpu_repro.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -2644,6 +2644,18 @@ def fn(a, dim, index, b):
         self.common(fn, inps)
         assert metrics.generated_cpp_vec_kernel_count == 2
 
+    def test_large_mean(self):
+        size = (30000, 100000)
+        t = torch.rand(size, dtype=torch.float)
+        op = torch.mean
+        expected = op(t)
+        actual = torch.compile(op)(t)
+        self.assertEqual(expected, actual)
+        with set_num_threads(1):
+            expected = op(t)
+            actual = torch.compile(op)(t)
+            self.assertEqual(expected, actual)
+
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @requires_vectorization
     @patch("torch.cuda.is_available", lambda: False)
```
