Skip to content

Commit a817baf

Browse files
authored
add dynamic padding substitution for nn.ConvTranspose2d (#1381)
* add dynamic padding substitution for nn.ConvTranspose2d
1 parent f7ac3c9 commit a817baf

File tree

6 files changed

+177
-1
lines changed

6 files changed

+177
-1
lines changed

model_compression_toolkit/core/pytorch/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DILATIONS = 'dilation'
3434
TENSOR_META = 'tensor_meta'
3535
FILTERS = 'out_channels'
36+
OUTPUT_PADDING = 'output_padding'
3637
TYPE = 'type'
3738
PAD = 'pad'
3839
VALUE = 'value'
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from typing import Tuple
16+
import torch.nn as nn
17+
import torch
18+
from model_compression_toolkit.core.pytorch.constants import OUTPUT_PADDING
19+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
20+
from model_compression_toolkit.core import common
21+
from model_compression_toolkit.core.common import BaseNode, Graph
22+
from model_compression_toolkit.logger import Logger
23+
24+
25+
class ConvtransposeDynamicPadding(common.BaseSubstitution):
    """
    Replace output_padding of nn.ConvTranspose2d to align with a dynamic output_size input.

    When ConvTranspose2d.forward is called with an explicit output_size, the effective
    output padding differs from the module's stored output_padding attribute. Since
    node.output_shape already reflects the dynamic output_size (when one was given),
    the correct output_padding is re-derived from it here and written back to the node.
    """

    def __init__(self):
        """
        Matches: nn.ConvTranspose2d
        """
        convtr_node = NodeOperationMatcher(nn.ConvTranspose2d)
        super().__init__(matcher_instance=convtr_node)

    def calc_dynamic_output_size(self, node: BaseNode) -> Tuple[int, ...]:
        """
        Calc the output padding to support dynamic output_size of nn.ConvTranspose2d.

        Args:
            node: node to calculate output padding for.

        Returns:
            Corrected output padding, one entry per spatial dimension.
        """
        # Rebuild the layer from the node's attributes so torch's own padding
        # arithmetic can be reused. NOTE(review): _output_padding is a private
        # torch API — its signature must match the installed torch version.
        convtr = nn.ConvTranspose2d(**node.framework_attr)
        num_spatial_dims = 2  # ConvTranspose2d always has 2 spatial dims (H, W).
        output_padding = convtr._output_padding(torch.randn(size=node.input_shape[0]),
                                                node.output_shape[0],
                                                convtr.stride,
                                                convtr.padding,
                                                convtr.kernel_size,
                                                num_spatial_dims,
                                                convtr.dilation)
        return tuple(output_padding)

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Substitute nn.ConvTranspose2d with corrected output_padding for cases of dynamic output_size.

        Args:
            graph: Graph we apply the substitution on.
            node: node that matched the pattern in the substitution init.

        Returns:
            Graph after applying the substitution.
        """
        # Reused nodes share attributes with their source node; skip them so the
        # same framework_attr dict is not updated more than once.
        if not node.reuse:
            output_padding = self.calc_dynamic_output_size(node)
            node.framework_attr.update({OUTPUT_PADDING: output_padding})
        return graph

model_compression_toolkit/core/pytorch/pytorch_implementation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
ScaledDotProductDecomposition
6363
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
6464
TransformFunctionCallMethod
65+
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.convtranspose_dynamic_padding import \
66+
ConvtransposeDynamicPadding
6567
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
6668
FunctionalConvSubstitution
6769
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
@@ -286,7 +288,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
286288
FunctionalBatchNorm(),
287289
FunctionalLayerNorm(),
288290
FunctionalLinear(),
289-
RemoveIdentity()]
291+
RemoveIdentity(),
292+
ConvtransposeDynamicPadding()]
290293

291294
def get_substitutions_pre_statistics_collection(self,
292295
quant_config: QuantizationConfig
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import torch
16+
from torch import nn
17+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
18+
AttachTpcToPytorch
19+
20+
from model_compression_toolkit.core import QuantizationConfig
21+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
22+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
23+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
24+
25+
26+
class Model(nn.Module):
    """Toy net with two downsample/upsample pairs; only the first transposed
    conv is called with a dynamic output_size in forward."""

    def __init__(self):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.downsample = nn.Conv2d(3, 16, **conv_cfg)
        self.upsample = nn.ConvTranspose2d(16, 3, **conv_cfg)
        self.downsample2 = nn.Conv2d(3, 16, **conv_cfg)
        self.upsample2 = nn.ConvTranspose2d(16, 3, **conv_cfg)

    def forward(self, x):
        out = self.downsample(x)
        out = self.upsample(out, output_size=[224, 224])  # <--- dynamic output_size
        out = self.downsample2(out)
        out = self.upsample2(out)  # <--- no dynamic output_size
        return out
40+
41+
42+
def data_gen():
    """Yield one single-tensor batch of random 224x224 RGB input."""
    batch = torch.rand(1, 3, 224, 224)
    yield [batch]
44+
45+
46+
def test_convtranspose_dynamic_output_size(minimal_tpc):
    """
    Check that graph preparation corrects ConvTranspose2d output_padding:
    the layer called with a dynamic output_size must get output_padding (1, 1)
    to reach 224x224, while the layer called without one keeps (0, 0).
    """
    # Sanity-run the model eagerly once (result discarded) so a plain forward
    # failure surfaces separately from graph-preparation failures.
    Model()(next(data_gen())[0])

    fw_impl = PytorchImplementation()
    fw_info = DEFAULT_PYTORCH_INFO
    model = Model()

    graph = graph_preparation_runner(model,
                                     data_gen,
                                     QuantizationConfig(),
                                     fw_info=fw_info,
                                     fw_impl=fw_impl,
                                     fqc=AttachTpcToPytorch().attach(minimal_tpc),
                                     mixed_precision_enable=False,
                                     running_gptq=False)

    nodes = graph.get_topo_sorted_nodes()

    # nodes[2] is presumably the first ConvTranspose2d (called with
    # output_size=[224, 224]) — its 224x224 output shape confirms this below.
    assert nodes[2].framework_attr['output_padding'] == (1,1)
    assert nodes[2].output_shape[0] == [1, 3, 224, 224]
    # nodes[4] is the ConvTranspose2d without a dynamic output_size: the
    # recomputed padding stays (0, 0) and the natural 223x223 shape holds.
    assert nodes[4].framework_attr['output_padding'] == (0,0)
    assert nodes[4].output_shape[0] == [1, 3, 223, 223]

0 commit comments

Comments
 (0)