Skip to content

Commit 27e9ca5

Browse files
pytorchbotmalfet
andauthored
[MKLDNN] Check that strides are positive (pytorch#153092)
[MKLDNN] Check that strides are positive (pytorch#151848) For pooling ops. Prevents division-by-zero when argument is wrong Fixes pytorch#149274 Pull Request resolved: pytorch#151848 Approved by: https://github.com/atalman (cherry picked from commit 6f32712) Co-authored-by: Nikita Shulga <[email protected]>
1 parent dab8130 commit 27e9ca5

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

aten/src/ATen/native/mkldnn/Utils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ std::vector<int64_t> pool_output_sizes(
1919
output_size[1] = input_size[1];
2020

2121
for (const auto i : c10::irange(2, input_size.size())) {
22+
TORCH_CHECK_VALUE(stride[i -2] > 0, "Strides must be positive!");
2223
output_size[i] = pooling_output_shape_pad_lr<int64_t>(
2324
input_size[i],
2425
kernel_size[i - 2],

test/test_mkldnn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,12 @@ def test_mkldnn_setflags_nowarn(self, device):
16231623
# Above should trigger no warnings regardless of configuration
16241624
self.assertEqual(len(w), 0)
16251625

1626+
def test_mkldnn_error_on_zero_stride(self, device):
1627+
# Regression test for https://github.com/pytorch/pytorch/issues/149274
1628+
x = torch.rand(1, 2, 3, 3).to_mkldnn()
1629+
with self.assertRaises(ValueError):
1630+
torch.mkldnn_max_pool2d(x, kernel_size=3, stride=0)
1631+
16261632

16271633
instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))
16281634

0 commit comments

Comments
 (0)