Skip to content

Commit 20be077

Browse files
jiayisunxpytorchmergebot
authored andcommitted
[Inductor] support masked vectorization for the tail_loop for float64 datatype (pytorch#163316)
**Summary:** Support masked vectorization for the tail_loop for float64 datatype. **Example:** ``` import torch def fn(x): return x * x x = torch.randn((22, 22), dtype=torch.double) with torch.no_grad(): compiled_fn = torch.compile(fn) compiled_fn(x) ``` **Generated code:** - Before ``` cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r''' #include <torch/csrc/inductor/cpp_prefix.h> extern "C" void kernel(const double* in_ptr0, double* out_ptr0) { { for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L)) { { if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L))) { auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16)); auto tmp1 = tmp0 * tmp0; tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16)); } if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L))) { for (int64_t x0_tail = static_cast<int64_t>(480L);x0_tail < static_cast<int64_t>(484L); x0_tail++) { auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)]; auto tmp1 = double(tmp0 * tmp0); out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1; } } } } } } ''') async_compile.wait(globals()) del async_compile class Runner: def __init__(self, partitions): self.partitions = partitions def recursively_apply_fns(self, fns): new_callables = [] for fn, c in zip(fns, self.partitions): new_callables.append(fn(c)) self.partitions = new_callables def call(self, args): arg0_1, = args args.clear() assert_size_stride(arg0_1, (22, 22), (22, 1)) buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64) # [Provenance debug handles] cpp_fused_mul_0:1 cpp_fused_mul_0(arg0_1, buf0) del arg0_1 return (buf0, ) ``` - After ``` cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r''' #include <torch/csrc/inductor/cpp_prefix.h> extern "C" void kernel(const double* in_ptr0, double* out_ptr0) { { for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L)) { { if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L))) { auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16)); auto tmp1 = tmp0 * tmp0; tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16)); } if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L))) { auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L)); auto tmp1 = tmp0 * tmp0; tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L)); } } } } } ''') async_compile.wait(globals()) del async_compile class Runner: def __init__(self, partitions): self.partitions = partitions def recursively_apply_fns(self, fns): new_callables = [] for fn, c in zip(fns, self.partitions): new_callables.append(fn(c)) self.partitions = new_callables def call(self, args): arg0_1, = args args.clear() assert_size_stride(arg0_1, (22, 22), (22, 1)) buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64) # [Provenance debug handles] cpp_fused_mul_0:1 cpp_fused_mul_0(arg0_1, buf0) del arg0_1 return (buf0, ) ``` Pull Request resolved: pytorch#163316 Approved by: https://github.com/mingfeima, https://github.com/jansel
1 parent 94eaeb9 commit 20be077

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

test/inductor/test_cpu_repro.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4810,6 +4810,23 @@ def fn(x):
48104810
self.common(fn, (x,))
48114811
check_metrics_vec_kernel_count(1)
48124812

4813+
# Tail vectorization case
4814+
x = torch.randn((37, 37), dtype=torch.double)
4815+
torch._dynamo.reset()
4816+
metrics.reset()
4817+
with torch.no_grad():
4818+
expected = fn(x)
4819+
compiled_fn = torch.compile(fn)
4820+
actual, code = run_and_get_cpp_code(compiled_fn, x)
4821+
self.assertEqual(expected, actual)
4822+
# 1 generated vec kernel
4823+
check_metrics_vec_kernel_count(1)
4824+
# Check that both main and tail loops are vectorized
4825+
if _can_check_vec_metrics():
4826+
FileCheck().check_count(
4827+
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
4828+
).run(code)
4829+
48134830
def test_double_reduction_vec(self):
48144831
def fn(x):
48154832
return x.sum(dim=1)
@@ -4819,6 +4836,23 @@ def fn(x):
48194836
self.common(fn, (x,))
48204837
check_metrics_vec_kernel_count(1)
48214838

4839+
# Tail vectorization case
4840+
x = torch.randn((37, 37), dtype=torch.double)
4841+
torch._dynamo.reset()
4842+
metrics.reset()
4843+
with torch.no_grad():
4844+
expected = fn(x)
4845+
compiled_fn = torch.compile(fn)
4846+
actual, code = run_and_get_cpp_code(compiled_fn, x)
4847+
self.assertEqual(expected, actual)
4848+
# 1 generated vec kernel
4849+
check_metrics_vec_kernel_count(1)
4850+
# Check that both main and tail loops are vectorized
4851+
if _can_check_vec_metrics():
4852+
FileCheck().check_count(
4853+
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
4854+
).run(code)
4855+
48224856
def test_convert_fp32_to_double_vec(self):
48234857
def fn(x):
48244858
return x.to(torch.double)
@@ -4828,6 +4862,23 @@ def fn(x):
48284862
self.common(fn, (x,))
48294863
check_metrics_vec_kernel_count(1)
48304864

4865+
# Tail vectorization case
4866+
x = torch.randn(37, 37)
4867+
torch._dynamo.reset()
4868+
metrics.reset()
4869+
with torch.no_grad():
4870+
expected = fn(x)
4871+
compiled_fn = torch.compile(fn)
4872+
actual, code = run_and_get_cpp_code(compiled_fn, x)
4873+
self.assertEqual(expected, actual)
4874+
# 1 generated vec kernel
4875+
check_metrics_vec_kernel_count(1)
4876+
# Check that both main and tail loops are vectorized
4877+
if _can_check_vec_metrics():
4878+
FileCheck().check_count(
4879+
"at::vec::convert<double,2,float,1>", 2, exactly=True
4880+
).run(code)
4881+
48314882
def test_convert_double_to_fp32_vec(self):
48324883
def fn(x):
48334884
return x.to(torch.float32)
@@ -4837,6 +4888,23 @@ def fn(x):
48374888
self.common(fn, (x,))
48384889
check_metrics_vec_kernel_count(1)
48394890

4891+
# Tail vectorization case
4892+
x = torch.randn((37, 37), dtype=torch.double)
4893+
torch._dynamo.reset()
4894+
metrics.reset()
4895+
with torch.no_grad():
4896+
expected = fn(x)
4897+
compiled_fn = torch.compile(fn)
4898+
actual, code = run_and_get_cpp_code(compiled_fn, x)
4899+
self.assertEqual(expected, actual)
4900+
# 1 generated vec kernel
4901+
check_metrics_vec_kernel_count(1)
4902+
# Check that both main and tail loops are vectorized
4903+
if _can_check_vec_metrics():
4904+
FileCheck().check_count(
4905+
"at::vec::convert<float,1,double,2>", 2, exactly=True
4906+
).run(code)
4907+
48404908
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
48414909
# https://github.com/pytorch/pytorch/issues/115260
48424910
p0 = torch.tensor([1.0879], dtype=torch.float16)

torch/_inductor/codegen/cpp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def get_export_declaration():
159159
]
160160

161161
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
162+
torch.float64,
162163
torch.float,
163164
torch.bfloat16,
164165
torch.float16,

0 commit comments

Comments
 (0)