
Commit f1c49c9

krastogi-in authored and pytorchmergebot committed
Check that the input is finite before calculation in the lowering of the pow function (pytorch#167723)
Fixes pytorch#167197

The inductor backend was trying to convert the float infinity value to an integer in the pow lowering (possibly for indexing, iteration counts, or type conversions). Python/C++ cannot convert float('inf') to an integer, so the lowering raised an overflow error.

Pull Request resolved: pytorch#167723
Approved by: https://github.com/shunting314
1 parent 265397e commit f1c49c9
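
For context, the failure mode is plain Python behavior, independent of inductor: int() cannot represent an infinite float, so an expression like b == int(b) raises before the comparison ever runs. A minimal standalone sketch:

import math

b = float("inf")

try:
    # int() cannot represent an infinite float, so this raises
    # before the equality comparison is evaluated.
    b == int(b)
except OverflowError as e:
    print(e)  # cannot convert float infinity to integer

# Guarding with math.isfinite short-circuits before the conversion,
# which is exactly the shape of the fix in the diff below.
print(isinstance(b, float) and math.isfinite(b) and b == int(b))  # False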

File tree

2 files changed: +36 −1 lines changed


test/inductor/test_torchinductor.py

Lines changed: 35 additions & 0 deletions
@@ -5528,6 +5528,32 @@ def fn(x):
             check_lowp=not is_halide_backend(self.device),  # misaligned addr fp16
         )

+    def test_lp_pool1d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool1d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool1d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8),),
+        )
+
+    def test_lp_pool2d_with_inf_norm(self):
+        # https://github.com/pytorch/pytorch/issues/167197
+        # Test that LPPool2d works with infinity norm (should behave like max pooling)
+        def fn(x):
+            return torch.nn.functional.lp_pool2d(
+                x, norm_type=float("inf"), kernel_size=2, stride=2
+            )
+
+        self.common(
+            fn,
+            (torch.randn(3, 4, 8, 8),),
+        )
+
     @tf32_on_and_off(0.006)
     @skip_if_gpu_halide  # slow
     def test_alexnet_prefix(self):
@@ -6307,6 +6333,15 @@ def fn(x):
         x = torch.randn([16, 16], device=self.device)
         self.assertEqual(cfn(x), fn(x))

+    def test_pow_infinite(self):
+        def fn(a, b):
+            return torch.pow(a, b)
+
+        opt = torch.compile(fn, backend="inductor")
+        a = torch.randn((3, 4, 8), device=self.device)
+        b = float("inf")
+        self.assertTrue(same(opt(a, b), fn(a, b)))
+
     def test_glu(self):
         def fn(x):
             return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)

torch/_inductor/lowering.py

Lines changed: 1 addition & 1 deletion
@@ -6361,7 +6361,7 @@ def pow_native(a, b):

 @register_lowering(aten.pow, broadcast=True)
 def pow(a, b):
-    if isinstance(b, float) and b == int(b):
+    if isinstance(b, float) and math.isfinite(b) and b == int(b):
         return pow(a, int(b))
     elif isinstance(b, float) and b == 0.5:
         return sqrt(a)
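
With the isfinite guard in place, an infinite exponent falls through to the later branches of the lowering instead of reaching the int conversion. A hypothetical standalone helper (illustration only, not the actual inductor code) mirroring the branch order:

import math

def classify_pow_exponent(b):
    # Hypothetical helper mirroring the branch order in the diff above;
    # illustration only, not the real inductor lowering.
    if isinstance(b, float) and math.isfinite(b) and b == int(b):
        return "integer fast path"   # e.g. b = 2.0
    elif isinstance(b, float) and b == 0.5:
        return "sqrt fast path"
    else:
        return "generic pow"         # b = float("inf") now lands here safely

print(classify_pow_exponent(2.0))           # integer fast path
print(classify_pow_exponent(0.5))           # sqrt fast path
print(classify_pow_exponent(float("inf")))  # generic pow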
