Skip to content

Commit 0c936c3

Browse files
isuruf authored
and pytorchmergebot committed
Add decomps for max_unpool (pytorch#133146)
Pull Request resolved: pytorch#133146 Approved by: https://github.com/amjames, https://github.com/eellison
1 parent 293fccf commit 0c936c3

File tree

6 files changed

+133
-134
lines changed

6 files changed

+133
-134
lines changed

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -925,10 +925,6 @@ aten::max_pool3d_with_indices
925925
aten::max_pool3d_with_indices.out
926926
aten::max_pool3d_with_indices_backward
927927
aten::max_pool3d_with_indices_backward.grad_input
928-
aten::max_unpool2d
929-
aten::max_unpool2d.out
930-
aten::max_unpool3d
931-
aten::max_unpool3d.out
932928
aten::median
933929
aten::median.dim
934930
aten::median.dim_values

torch/_decomp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,8 @@ def _core_aten_decompositions_post_autograd() -> (
541541
aten.logsumexp.default,
542542
aten.masked_fill,
543543
aten.masked_fill_,
544+
aten.max_unpool2d,
545+
aten.max_unpool3d,
544546
aten.mish,
545547
aten.mish_,
546548
aten.mse_loss,

torch/_decomp/decompositions.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,6 +2568,134 @@ def maybe_mask(vals, length, range_max, adaptive, dim):
25682568
return ret / (length_h * length_w)
25692569

25702570

2571+
def _max_unpoolnd(
2572+
self: TensorLike, indices: TensorLike, output_size: List[int], dim: int
2573+
):
2574+
# If the input tensors self and indices came from max_pool call as
2575+
# required by the documentation, this operation is deterministic
2576+
# because that ensures that if there are two entries in `indices`
2577+
# tensor that are equal, the corresponding values in `self` are also
2578+
# equal. If this condition is not satisfied, the operation is
2579+
# non-deterministic as one of the different values in `self` 'wins'.
2580+
utils.alert_not_deterministic(f"max_unpooling{dim}d_forward_out")
2581+
nc = reduce(operator.mul, self.shape[:-dim])
2582+
hw = reduce(operator.mul, output_size)
2583+
indices_nc_shape = [1] * self.ndim
2584+
indices_nc_shape[:-dim] = self.shape[:-dim]
2585+
indices_flat = (
2586+
indices + aten.arange(nc, device=self.device).view(indices_nc_shape) * hw
2587+
).reshape(-1)
2588+
2589+
output = self.new_zeros(list(self.shape[:-dim]) + list(output_size))
2590+
return aten._unsafe_index_put(
2591+
output.reshape(-1), [indices_flat], self.reshape(-1), accumulate=False
2592+
).view(output.shape)
2593+
2594+
2595+
@register_decomposition(aten.max_unpool2d)
@out_wrapper()
def max_unpool2d(
    self: TensorLike,
    indices: TensorLike,
    output_size: List[int],
):
    """Decomposition of aten.max_unpool2d.

    Validates dtype, rank, and shape constraints on the arguments, then
    delegates the actual scatter to _max_unpoolnd with dim=2.
    """
    torch._check(
        indices.dtype == torch.int64,
        lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
    )
    torch._check(
        len(output_size) == 2,
        lambda: f"There should be exactly two elements (height, width) in output_size, but got {len(output_size)} elements.",
    )
    torch._check(
        self.ndim in (3, 4),
        lambda: f"Input to max_unpooling2d should be a 3d or 4d Tensor, but got a tensor with {self.ndim} dimensions.",
    )
    torch._check(
        self.shape == indices.shape,
        lambda: f"Expected shape of indices to be same as that of the input tensor ({self.shape}) but got indices tensor with shape: {indices.shape}",
    )

    # Unpooling is undefined when any non-batch dimension is empty.
    for d in range(1, self.ndim):
        torch._check(
            self.size(d) > 0,
            lambda: f"max_unpooling2d(): Expected input to have non-zero size for non-batch dimensions, but got {self.shape} with dimension {d} being empty.",
        )

    return _max_unpoolnd(self, indices, output_size, 2)
2640+
2641+
2642+
@register_decomposition(aten.max_unpool3d)
@out_wrapper()
def max_unpool3d(
    input: TensorLike,
    indices: TensorLike,
    output_size: List[int],
    stride: List[int],
    padding: List[int],
):
    """Decomposition of aten.max_unpool3d.

    Validates dtype, rank, shape, stride, and padding constraints on the
    arguments, then delegates the actual scatter to _max_unpoolnd with dim=3.
    (`stride` and `padding` only participate in validation here.)
    """
    torch._check(
        indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
    )
    torch._check(
        input.ndim in (4, 5),
        lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
    )
    torch._check(
        len(output_size) == 3,
        lambda: f"There should be exactly three elements (depth, height, width) in output_size, but got {len(output_size)} elements.",
    )
    torch._check(
        len(stride) == 3,
        lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
    )
    torch._check(
        len(padding) == 3,
        lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
    )
    torch._check(
        input.shape == indices.shape,
        lambda: f"Expected shape of indices to be same as that of the input tensor ({input.shape}) but got indices tensor with shape: {indices.shape}",
    )

    # Unpooling is undefined when any non-batch dimension is empty.
    for d in range(1, input.ndim):
        torch._check(
            input.size(d) > 0,
            lambda: f"max_unpooling3d(): Expected input to have non-zero size for non-batch dimensions, but got {input.shape} with dimension {d} being empty.",
        )

    torch._check(
        all(s > 0 for s in stride),
        lambda: f"strides should be greater than zero, but got stride: {stride}",
    )

    return _max_unpoolnd(input, indices, output_size, 3)
2697+
2698+
25712699
@register_decomposition(aten.index_add_)
25722700
def index_add_(
25732701
x: TensorLike,

torch/_inductor/lowering.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,8 +2215,6 @@ def is_aligned(x):
22152215
make_fallback(aten._cdist_backward)
22162216

22172217
# 2) Medium
2218-
make_fallback(aten.max_unpool2d)
2219-
make_fallback(aten.max_unpool3d)
22202218
make_fallback(aten._trilinear)
22212219

22222220

torch/_meta_registrations.py

Lines changed: 0 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -4329,134 +4329,6 @@ def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
43294329
)
43304330

43314331

4332-
@register_meta(aten.max_unpool2d)
4333-
@out_wrapper()
4334-
def meta_max_unpool2d(self, indices, output_size):
4335-
utils.alert_not_deterministic("max_unpooling2d_forward_out")
4336-
4337-
torch._check(
4338-
indices.dtype == torch.int64,
4339-
lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4340-
)
4341-
torch._check(
4342-
len(output_size) == 2,
4343-
lambda: (
4344-
f"There should be exactly two elements (height, width) in output_size, "
4345-
f"but got {len(output_size)} elements."
4346-
),
4347-
)
4348-
4349-
oheight, owidth = output_size
4350-
4351-
torch._check(
4352-
self.ndim in (3, 4),
4353-
lambda: (
4354-
f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4355-
f"but got a tensor with {self.ndim} dimensions."
4356-
),
4357-
)
4358-
torch._check(
4359-
self.shape == indices.shape,
4360-
lambda: (
4361-
f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
4362-
f"but got indices tensor with shape: {indices.shape}"
4363-
),
4364-
)
4365-
4366-
for i in range(1, self.ndim):
4367-
torch._check(
4368-
self.size(i) > 0,
4369-
lambda: (
4370-
f"max_unpooling2d(): "
4371-
f"Expected input to have non-zero size for non-batch dimensions, "
4372-
f"but got {self.shape} with dimension {i} being empty."
4373-
),
4374-
)
4375-
4376-
self = self.contiguous()
4377-
4378-
if self.ndim == 3:
4379-
nchannels = self.size(0)
4380-
result = self.new_empty((nchannels, oheight, owidth))
4381-
else:
4382-
nbatch = self.size(0)
4383-
nchannels = self.size(1)
4384-
result = self.new_empty((nbatch, nchannels, oheight, owidth))
4385-
4386-
return result
4387-
4388-
4389-
def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4390-
torch._check(
4391-
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4392-
)
4393-
torch._check(
4394-
input.ndim in (4, 5),
4395-
lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4396-
)
4397-
torch._check(
4398-
len(output_size) == 3,
4399-
lambda: (
4400-
f"There should be exactly three elements (depth, height, width) in output_size, "
4401-
f"but got {len(output_size)} elements."
4402-
),
4403-
)
4404-
torch._check(
4405-
len(stride) == 3,
4406-
lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4407-
)
4408-
torch._check(
4409-
len(padding) == 3,
4410-
lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4411-
)
4412-
torch._check(
4413-
input.shape == indices.shape,
4414-
lambda: (
4415-
f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4416-
f"but got indices tensor with shape: {indices.shape}"
4417-
),
4418-
)
4419-
4420-
for i in range(1, input.ndim):
4421-
torch._check(
4422-
input.size(i) > 0,
4423-
lambda: (
4424-
f"{fn_name}: "
4425-
f"Expected input to have non-zero size for non-batch dimensions, "
4426-
f"but got {input.shape} with dimension {i} being empty."
4427-
),
4428-
)
4429-
4430-
torch._check(
4431-
stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4432-
lambda: f"strides should be greater than zero, but got stride: {stride}",
4433-
)
4434-
4435-
4436-
@register_meta(aten.max_unpool3d)
4437-
@out_wrapper()
4438-
def meta_max_unpool3d(self, indices, output_size, stride, padding):
4439-
utils.alert_not_deterministic("max_unpooling3d_forward_out")
4440-
4441-
_max_unpooling3d_shape_check(
4442-
self, indices, output_size, stride, padding, "max_unpooling3d()"
4443-
)
4444-
4445-
self = self.contiguous()
4446-
4447-
odepth, oheight, owidth = output_size
4448-
4449-
if self.ndim == 4:
4450-
nchannels = self.size(0)
4451-
result = self.new_empty((nchannels, odepth, oheight, owidth))
4452-
else:
4453-
nbatch = self.size(0)
4454-
nchannels = self.size(1)
4455-
result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))
4456-
4457-
return result
4458-
4459-
44604332
@register_meta(aten.max_pool3d_with_indices)
44614333
@out_wrapper("out", "indices")
44624334
def meta_max_pool3d_with_indices(

torch/testing/_internal/common_methods_invocations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15731,6 +15731,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1573115731
active_if=(not IS_MACOS)),
1573215732
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad',
1573315733
device_type='cpu'),
15734+
DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
1573415735
)),
1573515736
OpInfo('nn.functional.max_unpool1d',
1573615737
variant_test_name='grad',
@@ -15763,6 +15764,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1576315764
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
1576415765
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
1576515766
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
15767+
DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
1576615768
)),
1576715769
OpInfo('nn.functional.max_unpool2d',
1576815770
variant_test_name='grad',
@@ -15799,6 +15801,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1579915801
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
1580015802
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
1580115803
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
15804+
DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
1580215805
)),
1580315806
OpInfo('nn.functional.max_unpool3d',
1580415807
variant_test_name='grad',

0 commit comments

Comments
 (0)