@@ -9384,22 +9384,52 @@ def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
     yield SampleInput(make_input((2,)), offset=1)
     yield SampleInput(make_input((2,)), offset=-1)
 
+
+_UNPOOL_NAME_TO_DIM = {
+    'nn.functional.max_unpool1d': 1,
+    'nn.functional.max_unpool2d': 2,
+    'nn.functional.max_unpool3d': 3
+}
+
+
+def error_inputs_max_unpool(op_info, device, **kwargs):
+    """Error inputs for max_unpool: shape mismatch between input and indices."""
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name]
+
+    # Create mismatched shapes for input and indices
+    kwargs_dict = {'kernel_size': 3, 'stride': 2, 'padding': 0}
+    if pool_dim == 1:
+        input_shape = (8, 8)
+        indices_shape = (8, 7)
+    elif pool_dim == 2:
+        input_shape = (1, 1, 4, 4)
+        indices_shape = (1, 1, 4, 1)
+    else:  # pool_dim == 3
+        input_shape = (1, 1, 4, 4, 4)
+        indices_shape = (1, 1, 4, 4, 1)
+
+    yield ErrorInput(
+        SampleInput(
+            make_arg(input_shape),
+            args=(torch.zeros(indices_shape, device=device, dtype=torch.long),),
+            kwargs=kwargs_dict
+        ),
+        error_type=RuntimeError,
+        error_regex='Expected shape of indices to be'
+    )
+
+
 def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
     unpool_name_to_pool_method_dict = {
         'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
         'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
         'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
     }
 
-    unpool_name_to_dim = {
-        'nn.functional.max_unpool1d': 1,
-        'nn.functional.max_unpool2d': 2,
-        'nn.functional.max_unpool3d': 3
-    }
-
     unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
 
-    pool_dim = unpool_name_to_dim[op_info.name]
+    pool_dim = _UNPOOL_NAME_TO_DIM[op_info.name]
     pool_method = unpool_name_to_pool_method_dict[op_info.name]
 
     pool_op_info = copy.copy(op_info)
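For context (not part of the patch): the failure mode that `error_inputs_max_unpool` exercises can be reproduced directly. A minimal sketch, assuming CPU tensors and the same mismatched 1d shapes used above:

```python
import torch
import torch.nn.functional as F

# Input of shape (8, 8) paired with indices of shape (8, 7); max_unpool1d
# requires indices to have the same shape as the input, so this should
# raise a RuntimeError matching 'Expected shape of indices to be'.
inp = torch.randn(8, 8)
indices = torch.zeros(8, 7, dtype=torch.long)
try:
    F.max_unpool1d(inp, indices, kernel_size=3, stride=2, padding=0)
except RuntimeError as e:
    print(e)
```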
@@ -16252,6 +16282,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            assert_jit_shape_analysis=False,
            dtypes=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_max_unpool,
+           error_inputs_func=error_inputs_max_unpool,
            skips=(
                # Gradients are tested in `variant_test_name=grad` below.
                # We skip tests here because there is non-determinism in backward
@@ -16286,6 +16317,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            assert_jit_shape_analysis=False,
            dtypes=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_max_unpool,
+           error_inputs_func=error_inputs_max_unpool,
            skips=(
                # Gradients are tested in `variant_test_name=grad` below.
                # We skip tests here because there is non-determinism in backward
@@ -16323,6 +16355,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
            assert_jit_shape_analysis=False,
            dtypes=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_max_unpool,
+           error_inputs_func=error_inputs_max_unpool,
            skips=(
                # Gradients are tested in `variant_test_name=grad` below.
                # We skip tests here because there is non-determinism in backward
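Also for illustration (not part of the patch): once `error_inputs_func` is wired into the `OpInfo` entries as above, the harness calls it with the op info and a device, and each yielded `ErrorInput` carries the expected exception type and message pattern. A minimal sketch of checking one entry by hand, run in the test module's namespace so `error_inputs_max_unpool` and its helpers are in scope; `FakeOpInfo` is a hypothetical stand-in, since only `.name` is consulted:

```python
import re
import torch.nn.functional as F

class FakeOpInfo:
    # Hypothetical stand-in for the real OpInfo entry; only .name is used
    # by error_inputs_max_unpool to pick the pooling dimension.
    name = 'nn.functional.max_unpool1d'

for err in error_inputs_max_unpool(FakeOpInfo(), device='cpu'):
    si = err.sample_input
    raised = False
    try:
        F.max_unpool1d(si.input, *si.args, **si.kwargs)
    except err.error_type as e:
        raised = True
        assert re.search(err.error_regex, str(e))
    assert raised, 'expected max_unpool1d to reject mismatched indices'
```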