Skip to content

Commit 6d97488

Browse files
committed
add dynamic padding substitution for nn.ConvTranspose2d
1 parent 7bd8f49 commit 6d97488

File tree

5 files changed

+146
-2
lines changed

5 files changed

+146
-2
lines changed

model_compression_toolkit/core/pytorch/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DILATIONS = 'dilation'
3434
TENSOR_META = 'tensor_meta'
3535
FILTERS = 'out_channels'
36+
OUTPUT_PADDING = 'output_padding'
3637
TYPE = 'type'
3738
PAD = 'pad'
3839
VALUE = 'value'
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from typing import Tuple
16+
import torch.nn as nn
17+
import torch
18+
from model_compression_toolkit.core.pytorch.constants import OUTPUT_PADDING
19+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
20+
from model_compression_toolkit.core import common
21+
from model_compression_toolkit.core.common import BaseNode, Graph
22+
from model_compression_toolkit.logger import Logger
23+
24+
25+
class ConvtransposeDynamicPadding(common.BaseSubstitution):
    """
    Substitution that updates the output_padding attribute of nn.ConvTranspose2d
    nodes so the static node configuration reproduces the dynamic `output_size`
    argument that was passed to the layer at call time in the original model.
    """

    def __init__(self):
        """
        Matches: nn.ConvTranspose2d layers.
        """
        convtr_node = NodeOperationMatcher(nn.ConvTranspose2d)
        super().__init__(matcher_instance=convtr_node)

    def calc_dynamic_output_size(self, node: BaseNode) -> Tuple[int]:
        """
        Calc the output padding to support dynamic output_size of nn.ConvTranspose2d.

        Args:
            node: node to calculate output padding for.

        Returns:
            Output padding (tuple of ints) that makes a statically-configured
            ConvTranspose2d produce the node's recorded output shape.
        """
        # Rebuild the layer from the node's attributes so torch's own
        # output-padding computation can be reused.
        convtr = nn.ConvTranspose2d(**node.framework_attr)
        num_spatial_dims = 2  # ConvTranspose2d has exactly 2 spatial dims (H, W).
        # NOTE(review): _output_padding is a private torch API whose signature has
        # changed across torch versions — confirm against the supported torch range.
        output_padding = convtr._output_padding(torch.randn(size=node.input_shape[0]),
                                                node.output_shape[0],
                                                convtr.stride,
                                                convtr.padding,
                                                convtr.kernel_size,
                                                num_spatial_dims,
                                                convtr.dilation)
        return tuple(output_padding)

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Update the matched nn.ConvTranspose2d node's output_padding attribute so
        its static configuration matches the recorded (dynamic) output shape.

        Args:
            graph: Graph we apply the substitution on.
            node: node that matched the pattern in the substitution init.

        Returns:
            Graph after applying the substitution.
        """
        # Check that the output contains only a single tensor.
        if len(node.output_shape) > 1:
            Logger.critical('Output to nn.ConvTranspose2d should be a single tensor but got more than one.') # pragma: no cover
        output_padding = self.calc_dynamic_output_size(node)
        node.framework_attr.update({OUTPUT_PADDING: output_padding})
        return graph

model_compression_toolkit/core/pytorch/pytorch_implementation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
ScaledDotProductDecomposition
6363
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
6464
TransformFunctionCallMethod
65+
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.convtranspose_dynamic_padding import \
66+
ConvtransposeDynamicPadding
6567
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
6668
FunctionalConvSubstitution
6769
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
@@ -286,7 +288,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
286288
FunctionalBatchNorm(),
287289
FunctionalLayerNorm(),
288290
FunctionalLinear(),
289-
RemoveIdentity()]
291+
RemoveIdentity(),
292+
ConvtransposeDynamicPadding()]
290293

291294
def get_substitutions_pre_statistics_collection(self,
292295
quant_config: QuantizationConfig
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import torch
16+
import torch.nn as nn
17+
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
18+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
19+
20+
"""
21+
This test checks dynamic output size for nn.ConvTranspose2d.
22+
"""
23+
24+
25+
class ConvTranspose2dDynamicNet(nn.Module):
    """Toy network: a strided conv downsamples the input, then a transposed
    conv upsamples it back using a dynamic ``output_size`` at call time."""

    def __init__(self):
        super().__init__()
        self.downsample = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.upsample = nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1)

    def forward(self, x):
        # The explicit output_size makes the transposed conv resolve its
        # output padding dynamically (the case under test).
        reduced = self.downsample(x)
        return self.upsample(reduced, output_size=[224, 224])
35+
36+
37+
class ConvTranspose2dDynamicNetTest(BasePytorchTest):
    """
    This test checks that a model using nn.ConvTranspose2d with a dynamic
    output_size keeps the float model's output shape after quantization.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test)

    def create_inputs_shape(self):
        # Single input: a batch of 3-channel 224x224 images.
        return [[self.val_batch_size, 3, 224, 224]]

    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
        # Run the float and each quantized model on the same input and verify
        # the output shapes match (the dynamic output_size must be reproduced).
        in_torch_tensor = to_torch_tensor(input_x[0])
        for _, qmodel in quantized_models.items():
            y_float = float_model(in_torch_tensor)
            y_quant = qmodel(in_torch_tensor)
            self.unit_test.assertTrue(y_float.shape == y_quant.shape,
                                      msg=f'Out shape of the quantized model is not as the float model!')

    def create_feature_network(self, input_shape):
        return ConvTranspose2dDynamicNet()

tests/pytorch_tests/model_tests/test_feature_models_runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from tests.pytorch_tests.model_tests.feature_models.constant_conv_substitution_test import ConstantConvSubstitutionTest, \
4848
ConstantConvReuseSubstitutionTest, ConstantConvTransposeSubstitutionTest
4949
from tests.pytorch_tests.model_tests.feature_models.conv2d_replacement_test import DwConv2dReplacementTest
50+
from tests.pytorch_tests.model_tests.feature_models.convtranspose_dynamic_output_size_test import \
51+
ConvTranspose2dDynamicNetTest
5052
from tests.pytorch_tests.model_tests.feature_models.dynamic_size_inputs_test import ReshapeNetTest
5153
from tests.pytorch_tests.model_tests.feature_models.gptq_test import GPTQAccuracyTest, GPTQWeightsUpdateTest, \
5254
GPTQLearnRateZeroTest
@@ -877,7 +879,11 @@ def test_manual_bit_width_selection_by_layer_name(self):
877879
ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1_bn')], [2, 4]).run_test()
878880
ManualBitWidthByLayerNameTest(self, [NodeNameFilter('add'), NodeNameFilter('conv1_bn')], 4).run_test()
879881

880-
882+
def test_convtranspose_dynamic_output_size(self):
    """
    This test checks the nn.ConvTranspose2d substitution for dynamic output size.
    """
    ConvTranspose2dDynamicNetTest(self).run_test()
881887

882888

883889
if __name__ == '__main__':

0 commit comments

Comments
 (0)