Skip to content

Commit eb96b71

Browse files
authored
Fix mac computation (#1363)
* fix mac computation * fix code accessing tpc.operator_set in case its None
1 parent 71354df commit eb96b71

File tree

8 files changed

+299
-41
lines changed

8 files changed

+299
-41
lines changed

model_compression_toolkit/core/common/graph/base_node.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -440,18 +440,28 @@ def has_any_weight_attr_to_quantize(self) -> bool:
440440

441441
return any([self.is_weights_quantization_enabled(attr) for attr in self.get_node_weights_attributes()])
442442

443-
def get_total_output_params(self) -> float:
443+
# TODO it makes more sense to standardize the input/output shapes at node creation.
444+
def get_output_shapes_list(self) -> List[tuple]:
444445
"""
445-
Calculates the output size of the node.
446+
Return output shape in a standardized form as a list of tuples.
446447
447-
Returns: Output size.
448+
Returns:
449+
A list of output shape tuples.
448450
"""
449451
# shape can be tuple or list, and multiple shapes can be packed in list or tuple
450452
if self.output_shape and isinstance(self.output_shape[0], (tuple, list)):
451-
output_shapes = self.output_shape
453+
output_shapes = [tuple(s) for s in self.output_shape]
452454
else:
453-
output_shapes = [self.output_shape]
455+
output_shapes = [tuple(self.output_shape)]
456+
return output_shapes
457+
458+
def get_total_output_params(self) -> float:
459+
"""
460+
Calculates the output size of the node.
454461
462+
Returns: Output size.
463+
"""
464+
output_shapes = self.get_output_shapes_list()
455465
# remove batch size (first element) from output shape
456466
output_shapes = [s[1:] for s in output_shapes]
457467
# for scalar shape (None,) prod returns 1
@@ -550,7 +560,7 @@ def has_activation_quantization_enabled_candidate(self) -> bool:
550560
"""
551561

552562
return len(self.candidates_quantization_cfg) > 0 and \
553-
any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
563+
any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg])
554564

555565
def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]:
556566
"""

model_compression_toolkit/core/keras/keras_implementation.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -565,28 +565,24 @@ def get_node_mac_operations(self,
565565
566566
Returns: The MAC count og the operation
567567
"""
568-
569-
output_shape = node.output_shape
570-
kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
571-
output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
572-
573-
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
574-
# (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
575-
return np.prod([x for x in output_shape if x is not None]) * \
576-
kernel_shape[input_channel_axis] * \
577-
(kernel_shape[0] * kernel_shape[1])
578-
elif node.is_match_type(DepthwiseConv2D):
579-
# Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
580-
return node.framework_attr.get(DEPTH_MULTIPLIER) * \
581-
np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
582-
kernel_shape[input_channel_axis] * \
583-
(kernel_shape[0] * kernel_shape[1])
584-
elif node.is_match_type(Dense):
585-
# IN * OUT
586-
return kernel_shape[0] * kernel_shape[1]
587-
else:
568+
kernels = fw_info.get_kernel_op_attributes(node.type)
569+
if not kernels or kernels[0] is None:
588570
return 0
589571

572+
assert len(kernels) == 1
573+
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
574+
575+
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
576+
h, w = node.get_output_shapes_list()[0][-3:-1]
577+
return np.prod(kernel_shape) * h * w
578+
579+
if node.is_match_type(Dense):
580+
# IN * OUT * (all previous dims[:-1])
581+
_, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
582+
return node.get_total_output_params() * kernel_shape[input_channel_axis]
583+
584+
return 0
585+
590586
def apply_second_moment_correction(self,
591587
quantized_model: Any,
592588
core_config: CoreConfig,

model_compression_toolkit/core/pytorch/pytorch_implementation.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -506,21 +506,23 @@ def get_node_mac_operations(self,
506506
507507
Returns: The MAC count of the operation
508508
"""
509+
kernels = fw_info.get_kernel_op_attributes(node.type)
510+
if not kernels or kernels[0] is None:
511+
return 0
509512

510-
output_shape = node.output_shape[0]
511-
kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
512-
output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
513+
assert len(kernels) == 1
514+
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
513515

514516
if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
515-
# (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
516-
return np.prod([x for x in output_shape if x is not None]) * \
517-
kernel_shape[input_channel_axis] * \
518-
(kernel_shape[0] * kernel_shape[1])
519-
elif node.is_match_type(Linear):
520-
# IN * OUT
521-
return kernel_shape[0] * kernel_shape[1]
522-
else:
523-
return 0
517+
h, w = node.get_output_shapes_list()[0][-2:]
518+
return np.prod(kernel_shape) * h * w
519+
520+
if node.is_match_type(Linear):
521+
# IN * OUT * (all previous dims[:-1])
522+
_, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
523+
return node.get_total_output_params() * kernel_shape[input_channel_axis]
524+
525+
return 0
524526

525527
def apply_second_moment_correction(self,
526528
quantized_model: Any,

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def attach(self, tpc_model: TargetPlatformCapabilities,
3939

4040
tpc = FrameworkQuantizationCapabilities(tpc_model)
4141
custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
42-
42+
operator_set = tpc_model.operator_set or ()
4343
with tpc:
44-
for opset in tpc_model.operator_set:
44+
for opset in operator_set:
4545
if isinstance(opset, OperatorsSet): # filter out OperatorsSetConcat
4646
if opset.name in custom_opset2layer:
4747
custom_opset_layers = custom_opset2layer[opset.name]

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(self,
5252
self.op_sets_to_layers = OperationsToLayers() # Init an empty OperationsToLayers
5353
self.layer2qco, self.filterlayer2qco = {}, {} # Init empty mappings from layers/LayerFilterParams to QC options
5454
# Track the unused opsets for warning purposes.
55-
self.__tpc_opsets_not_used = [s.name for s in tpc.operator_set]
55+
operator_set = tpc.operator_set or ()
56+
self.__tpc_opsets_not_used = [s.name for s in operator_set]
5657
self.remove_fusing_names_from_not_used_list()
5758

5859
def get_layers_by_opset_name(self, opset_name: str) -> List[Any]:

tests_pytest/conftest.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from mct_quantizers import QuantizationMethod
16+
from unittest.mock import Mock
17+
18+
from pytest import fixture
19+
20+
from model_compression_toolkit.core import FrameworkInfo, QuantizationConfig
21+
from model_compression_toolkit.core.common import Graph
22+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23+
from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig, Signedness, \
24+
QuantizationConfigOptions, TargetPlatformCapabilities
25+
26+
27+
@fixture
28+
def default_op_quant_cfg():
29+
return OpQuantizationConfig(
30+
default_weight_attr_config={},
31+
attr_weights_configs_mapping={},
32+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
33+
activation_n_bits=8,
34+
supported_input_activation_n_bits=[8],
35+
enable_activation_quantization=True,
36+
quantization_preserving=False,
37+
fixed_scale=None,
38+
fixed_zero_point=None,
39+
simd_size=32,
40+
signedness=Signedness.AUTO)
41+
42+
43+
@fixture
44+
def default_quant_cfg_options(default_op_quant_cfg):
45+
return QuantizationConfigOptions(quantization_configurations=[default_op_quant_cfg])
46+
47+
48+
@fixture
49+
def minimal_tpc(default_quant_cfg_options):
50+
return TargetPlatformCapabilities(default_qco=default_quant_cfg_options,
51+
tpc_platform_type='test',
52+
operator_set=None,
53+
fusing_patterns=None)
54+
55+
56+
@fixture
57+
def graph_mock():
58+
""" Basic Graph mock. """
59+
return Mock(spec_set=Graph, nodes=[])
60+
61+
62+
@fixture
63+
def fw_impl_mock():
64+
""" Basic FrameworkImplementation mock. """
65+
return Mock(spec_set=FrameworkImplementation)
66+
67+
68+
@fixture
69+
def fw_info_mock():
70+
""" Basic FrameworkInfo mock. """
71+
return Mock(spec_set=FrameworkInfo)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import numpy as np
16+
from keras.layers import Conv2D, Conv2DTranspose, DepthwiseConv2D, Dense, Input, Flatten
17+
import keras
18+
19+
from model_compression_toolkit.core import QuantizationConfig
20+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
21+
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
22+
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
23+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
24+
AttachTpcToKeras
25+
26+
27+
def data_gen():
28+
yield [np.random.randn(1, 28, 32, 10)]
29+
30+
31+
def build_model():
32+
x = Input(shape=(28, 32, 10))
33+
y = Conv2D(filters=20, kernel_size=(5, 4))(x)
34+
y = Conv2D(filters=15, kernel_size=(4, 6), groups=5)(y)
35+
y = Conv2D(filters=8, kernel_size=(3, 3), strides=2)(y)
36+
y = Conv2D(filters=12, kernel_size=(3, 3), dilation_rate=2)(y)
37+
y = Conv2DTranspose(filters=20, kernel_size=(5, 3))(y)
38+
y = Conv2DTranspose(filters=10, kernel_size=(3, 3), strides=2)(y)
39+
y = Conv2DTranspose(filters=5, kernel_size=(3, 3), dilation_rate=2)(y)
40+
y = DepthwiseConv2D(kernel_size=(2, 3), depth_multiplier=4)(y)
41+
y = DepthwiseConv2D(kernel_size=(3, 3), depth_multiplier=2, strides=3)(y)
42+
y = DepthwiseConv2D(kernel_size=(3, 3), depth_multiplier=2, dilation_rate=2)(y)
43+
y = Dense(10)(y) # 4d input
44+
y = Flatten()(y)
45+
y = Dense(5)(y) # 2d (vector) input
46+
return keras.Model(inputs=x, outputs=y)
47+
48+
49+
def test_get_mac(minimal_tpc):
50+
fw_impl = KerasImplementation()
51+
model = build_model()
52+
fw_info = DEFAULT_KERAS_INFO
53+
54+
graph = graph_preparation_runner(model,
55+
data_gen,
56+
QuantizationConfig(linear_collapsing=False),
57+
fw_info=fw_info,
58+
fw_impl=fw_impl,
59+
fqc=AttachTpcToKeras().attach(minimal_tpc),
60+
mixed_precision_enable=False,
61+
running_gptq=False)
62+
63+
nodes = graph.get_topo_sorted_nodes()
64+
assert len(nodes) == 14, nodes
65+
assert fw_impl.get_node_mac_operations(nodes[0], fw_info) == 0
66+
assert fw_impl.get_node_mac_operations(nodes[1], fw_info) == (10*20*5*4)*24*29
67+
assert fw_impl.get_node_mac_operations(nodes[2], fw_info) == (4*3*4*6)*5*21*24
68+
assert fw_impl.get_node_mac_operations(nodes[3], fw_info) == (15*8*3*3)*10*11
69+
assert fw_impl.get_node_mac_operations(nodes[4], fw_info) == (8*12*3*3)*6*7
70+
assert fw_impl.get_node_mac_operations(nodes[5], fw_info) == (12*20*5*3)*10*9
71+
assert fw_impl.get_node_mac_operations(nodes[6], fw_info) == (20*10*3*3)*21*19
72+
assert fw_impl.get_node_mac_operations(nodes[7], fw_info) == (10*5*3*3)*25*23
73+
assert fw_impl.get_node_mac_operations(nodes[8], fw_info) == (5*2*3*4)*24*21
74+
assert fw_impl.get_node_mac_operations(nodes[9], fw_info) == (10*3*3*4)*8*7
75+
assert fw_impl.get_node_mac_operations(nodes[10], fw_info) == (40*3*3*2)*4*3
76+
assert fw_impl.get_node_mac_operations(nodes[11], fw_info) == 4*3*(80*10)
77+
assert fw_impl.get_node_mac_operations(nodes[12], fw_info) == 0
78+
assert fw_impl.get_node_mac_operations(nodes[13], fw_info) == (4*3*10)*5
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import torch
16+
from torch import nn
17+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
18+
AttachTpcToPytorch
19+
20+
from model_compression_toolkit.core import QuantizationConfig
21+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
22+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
23+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
24+
25+
26+
def data_gen():
27+
yield [torch.rand(1, 10, 28, 32)]
28+
29+
30+
class Model(nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.conv1 = nn.Conv2d(10, 20, kernel_size=(5, 4))
34+
self.conv2 = nn.Conv2d(20, 15, kernel_size=(4, 6), groups=5)
35+
self.conv3 = nn.Conv2d(15, 8, kernel_size=(3, 3), stride=2)
36+
self.conv4 = nn.Conv2d(8, 12, kernel_size=(3, 3), dilation=2)
37+
self.convtr1 = nn.ConvTranspose2d(12, 20, kernel_size=(5, 3))
38+
self.convtr2 = nn.ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=2)
39+
self.convtr3 = nn.ConvTranspose2d(10, 5, kernel_size=(3, 3), dilation=2)
40+
self.dwconv1 = nn.Conv2d(5, 20, kernel_size=(2, 3), groups=5)
41+
self.dwconv2 = nn.Conv2d(20, 40, kernel_size=(3, 3), groups=20, stride=3)
42+
self.dwconv3 = nn.Conv2d(40, 80, kernel_size=(3, 3), groups=40, dilation=2)
43+
self.fc1 = nn.Linear(80, 10)
44+
self.flatten = nn.Flatten()
45+
self.fc2 = nn.Linear(120, 5)
46+
47+
def forward(self, x):
48+
x = self.conv1(x)
49+
x = self.conv2(x)
50+
x = self.conv3(x)
51+
x = self.conv4(x)
52+
x = self.convtr1(x)
53+
x = self.convtr2(x)
54+
x = self.convtr3(x)
55+
x = self.dwconv1(x)
56+
x = self.dwconv2(x)
57+
x = self.dwconv3(x)
58+
x = torch.permute(x, [0, 2, 3, 1])
59+
x = self.fc1(x)
60+
x = self.flatten(x)
61+
x = self.fc2(x)
62+
return x
63+
64+
65+
def test_get_mac(minimal_tpc):
66+
Model()(next(data_gen())[0])
67+
68+
fw_impl = PytorchImplementation()
69+
fw_info = DEFAULT_PYTORCH_INFO
70+
model = Model()
71+
72+
graph = graph_preparation_runner(model,
73+
data_gen,
74+
QuantizationConfig(linear_collapsing=False),
75+
fw_info=fw_info,
76+
fw_impl=fw_impl,
77+
fqc=AttachTpcToPytorch().attach(minimal_tpc),
78+
mixed_precision_enable=False,
79+
running_gptq=False)
80+
81+
nodes = graph.get_topo_sorted_nodes()
82+
# assert len(nodes) == 14, nodes
83+
assert fw_impl.get_node_mac_operations(nodes[0], fw_info) == 0
84+
assert fw_impl.get_node_mac_operations(nodes[1], fw_info) == (10*20*5*4)*24*29
85+
assert fw_impl.get_node_mac_operations(nodes[2], fw_info) == (4*3*4*6)*5*21*24
86+
assert fw_impl.get_node_mac_operations(nodes[3], fw_info) == (15*8*3*3)*10*11
87+
assert fw_impl.get_node_mac_operations(nodes[4], fw_info) == (8*12*3*3)*6*7
88+
assert fw_impl.get_node_mac_operations(nodes[5], fw_info) == (12*20*5*3)*10*9
89+
assert fw_impl.get_node_mac_operations(nodes[6], fw_info) == (20*10*3*3)*21*19
90+
assert fw_impl.get_node_mac_operations(nodes[7], fw_info) == (10*5*3*3)*25*23
91+
assert fw_impl.get_node_mac_operations(nodes[8], fw_info) == (5*2*3*4)*24*21
92+
assert fw_impl.get_node_mac_operations(nodes[9], fw_info) == (10*3*3*4)*8*7
93+
assert fw_impl.get_node_mac_operations(nodes[10], fw_info) == (40*3*3*2)*4*3
94+
assert fw_impl.get_node_mac_operations(nodes[10], fw_info) == (40*3*3*2)*4*3
95+
assert fw_impl.get_node_mac_operations(nodes[11], fw_info) == 0
96+
assert fw_impl.get_node_mac_operations(nodes[12], fw_info) == 4*3*(80*10)
97+
assert fw_impl.get_node_mac_operations(nodes[13], fw_info) == 0
98+
assert fw_impl.get_node_mac_operations(nodes[14], fw_info) == (4*3*10)*5
99+
100+

0 commit comments

Comments
 (0)