Commit 20be077
[Inductor] support masked vectorization for the tail_loop for float64 datatype (pytorch#163316)
**Summary:**
Support masked vectorization for the tail_loop for float64 datatype.
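Before this change, the CPP backend vectorized only the full-width chunks of a float64 loop and fell back to a scalar loop for the remaining tail elements; with this change, the tail is handled by the same vector operations using count-limited (masked) `loadu`/`store` calls, as the before/after codegen below shows.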
**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```
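With this shape the kernel iterates over 22 × 22 = 484 contiguous elements in chunks of 16 doubles (`VectorizedN<double,2>`, consistent with two 8-lane AVX-512 registers), so 480 elements are covered by full vector iterations and the last 4 elements (indices 480–483) form the tail.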
**Generated code:**
- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L); x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
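In the before codegen, the tail (elements 480–483) falls back to a scalar `x0_tail` loop that processes one element at a time.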
- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
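In the after codegen, the tail is now a single masked vector iteration: the same `loadu`/`store` calls, but with an element count of 4 instead of 16. Below is a minimal standalone sketch of the masked-tail idea in plain C++ (no ATen dependency; `masked_load`/`masked_store` are hypothetical stand-ins for `at::vec`'s count-limited `loadu`/`store`):
```
// Sketch of masked tail vectorization: the main loop always works on full
// 16-element chunks; the last (partial) chunk reuses the same vector body
// with count-limited load/store helpers instead of a scalar fallback loop.
#include <cstdint>
#include <cstdio>

constexpr int64_t kWidth = 16;  // matches VectorizedN<double,2> in the codegen above

// Load `count` doubles into a fixed-width lane buffer, zero-padding the rest
// (stands in for at::vec's loadu(ptr, count)).
static void masked_load(double* lanes, const double* src, int64_t count) {
    for (int64_t i = 0; i < kWidth; ++i)
        lanes[i] = (i < count) ? src[i] : 0.0;
}

// Store only the first `count` lanes back (stands in for store(ptr, count)).
static void masked_store(double* dst, const double* lanes, int64_t count) {
    for (int64_t i = 0; i < count; ++i)
        dst[i] = lanes[i];
}

int main() {
    constexpr int64_t n = 484;  // 22 * 22 elements, as in the example
    static double in[n], out[n];
    for (int64_t i = 0; i < n; ++i) in[i] = 0.5 * i;

    for (int64_t x0 = 0; x0 < n; x0 += kWidth) {
        const int64_t count = (n - x0 < kWidth) ? (n - x0) : kWidth;  // 4 on the last chunk
        double lanes[kWidth];
        masked_load(lanes, in + x0, count);
        for (int64_t i = 0; i < kWidth; ++i) lanes[i] *= lanes[i];  // x * x
        masked_store(out + x0, lanes, count);
    }
    std::printf("out[483] = %f\n", out[483]);  // (0.5 * 483)^2
    return 0;
}
```
On real hardware the count-limited load/store compiles down to masked SIMD instructions, so the tail costs one vector iteration instead of up to 15 scalar ones.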
Pull Request resolved: pytorch#163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
**Files changed:** 2 files, +69 −0
- test/inductor
- torch/_inductor/codegen