Skip to content

Commit fa21963

Browse files
lingebengmalfet
authored andcommitted
[MPS] Add input/indices shape validation for MaxUnpool{1,2,3}d (pytorch#169261)
Add missing shape validation between `input` and `indices` tensors for `nn.MaxUnpool{1,2,3}d` on MPS backend Fixes pytorch#169235 Pull Request resolved: pytorch#169261 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <[email protected]>
1 parent 3cf2f19 commit fa21963

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

aten/src/ATen/native/mps/operations/Pooling.mm

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,13 @@ static void max_unpool_out_mps_template(const Tensor& input,
570570
" elements but got ",
571571
output_size_.size());
572572

573+
// Check that input and indices have the same shape
574+
TORCH_CHECK(input.sizes() == indices.sizes(),
575+
"Expected shape of indices to be same as that of the input tensor (",
576+
input.sizes(),
577+
") but got indices tensor with shape: ",
578+
indices.sizes());
579+
573580
auto dims = input.dim();
574581
auto leading_dims = input.dim() - pooling_dims;
575582

torch/testing/_internal/common_methods_invocations.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9384,22 +9384,52 @@ def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
93849384
yield SampleInput(make_input((2,)), offset=1)
93859385
yield SampleInput(make_input((2,)), offset=-1)
93869386

9387+
9388+
_UNPOOL_NAME_TO_DIM = {
9389+
'nn.functional.max_unpool1d': 1,
9390+
'nn.functional.max_unpool2d': 2,
9391+
'nn.functional.max_unpool3d': 3
9392+
}
9393+
9394+
9395+
def error_inputs_max_unpool(op_info, device, **kwargs):
9396+
"""Error inputs for max_unpool: shape mismatch between input and indices."""
9397+
make_arg = partial(make_tensor, device=device, dtype=torch.float32)
9398+
pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name]
9399+
9400+
# Create mismatched shapes for input and indices
9401+
kwargs_dict = {'kernel_size': 3, 'stride': 2, 'padding': 0}
9402+
if pool_dim == 1:
9403+
input_shape = (8, 8)
9404+
indices_shape = (8, 7)
9405+
elif pool_dim == 2:
9406+
input_shape = (1, 1, 4, 4)
9407+
indices_shape = (1, 1, 4, 1)
9408+
else: # pool_dim == 3
9409+
input_shape = (1, 1, 4, 4, 4)
9410+
indices_shape = (1, 1, 4, 4, 1)
9411+
9412+
yield ErrorInput(
9413+
SampleInput(
9414+
make_arg(input_shape),
9415+
args=(torch.zeros(indices_shape, device=device, dtype=torch.long),),
9416+
kwargs=kwargs_dict
9417+
),
9418+
error_type=RuntimeError,
9419+
error_regex='Expected shape of indices to be'
9420+
)
9421+
9422+
93879423
def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
93889424
unpool_name_to_pool_method_dict = {
93899425
'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
93909426
'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
93919427
'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
93929428
}
93939429

9394-
unpool_name_to_dim = {
9395-
'nn.functional.max_unpool1d': 1,
9396-
'nn.functional.max_unpool2d': 2,
9397-
'nn.functional.max_unpool3d': 3
9398-
}
9399-
94009430
unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
94019431

9402-
pool_dim = unpool_name_to_dim[op_info.name]
9432+
pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name]
94039433
pool_method = unpool_name_to_pool_method_dict[op_info.name]
94049434

94059435
pool_op_info = copy.copy(op_info)
@@ -16252,6 +16282,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1625216282
assert_jit_shape_analysis=False,
1625316283
dtypes=floating_types_and(torch.float16, torch.bfloat16),
1625416284
sample_inputs_func=sample_inputs_max_unpool,
16285+
error_inputs_func=error_inputs_max_unpool,
1625516286
skips=(
1625616287
# Gradients are tested in `variant_test_name=grad` below.
1625716288
# We skip tests here because there is non-determinism in backward
@@ -16286,6 +16317,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1628616317
assert_jit_shape_analysis=False,
1628716318
dtypes=floating_types_and(torch.float16, torch.bfloat16),
1628816319
sample_inputs_func=sample_inputs_max_unpool,
16320+
error_inputs_func=error_inputs_max_unpool,
1628916321
skips=(
1629016322
# Gradients are tested in `variant_test_name=grad` below.
1629116323
# We skip tests here because there is non-determinism in backward
@@ -16323,6 +16355,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1632316355
assert_jit_shape_analysis=False,
1632416356
dtypes=floating_types_and(torch.float16, torch.bfloat16),
1632516357
sample_inputs_func=sample_inputs_max_unpool,
16358+
error_inputs_func=error_inputs_max_unpool,
1632616359
skips=(
1632716360
# Gradients are tested in `variant_test_name=grad` below.
1632816361
# We skip tests here because there is non-determinism in backward

0 commit comments

Comments
 (0)