Skip to content

Commit 07ac2d4

Browse files
committed
debugging version(040700gouda)
1 parent 29f6283 commit 07ac2d4

File tree

3 files changed

+278
-4
lines changed

3 files changed

+278
-4
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,10 @@ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
209209
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
210210
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
211211

212+
print('ffff', filtered_nodes)
212213
for n in filtered_nodes:
213214
attr_to_change_bit_width = []
214-
215+
print('nnnnnnn', n, n.get_node_weights_attributes())
215216
attrs_str = n.get_node_weights_attributes()
216217
if len(attrs_str) == 0:
217218
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,13 @@ def set_quantization_configuration_to_graph(graph: Graph,
6969

7070
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
7171
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
72+
print('nodes_to_manipulate_weights_bit_widths:', nodes_to_manipulate_weights_bit_widths)
7273

7374
for n in graph.nodes:
75+
print('n', n)
7476
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
7577
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
78+
print('manual_bit_width_override', manual_bit_width_override)
7679
set_quantization_configs_to_node(node=n,
7780
graph=graph,
7881
quant_config=quant_config,
@@ -169,7 +172,13 @@ def set_quantization_configs_to_node(node: BaseNode,
169172
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
170173
"""
171174
node_qc_options = node.get_qco(fqc)
175+
print('8 node_qc_options', node_qc_options)
176+
print('8', type(node_qc_options.quantization_configurations))
177+
#print('8', node_qc_options.quantization_configurations[0].attr_weights_configs_mapping['kernel_attr'].weights_n_bits)
172178
base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)
179+
print('9 base_config', base_config)
180+
print('9 base_config.attr_weights_configs_mapping', base_config.attr_weights_configs_mapping)
181+
print('9 node_qc_options_list', len(node_qc_options_list), node_qc_options_list)
173182

174183
# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
175184
# and update base_config accordingly.
@@ -192,7 +201,7 @@ def set_quantization_configs_to_node(node: BaseNode,
192201
base_config,
193202
node,
194203
mixed_precision_enable=mixed_precision_enable)
195-
204+
#print('node.candidates_quantization_cfg', node.candidates_quantization_cfg)
196205
# sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits
197206
# (in reversed order). since only kernel attribute is quantized in weights mixed precision,
198207
# if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
@@ -430,22 +439,30 @@ def filter_weights_qc_options_with_manual_bit_width(
430439
node_qc_options_weights_list = []
431440
override_attr, override_bitwidth = [], []
432441

442+
print('weights_manual_bit_width_override', weights_manual_bit_width_override, type(weights_manual_bit_width_override))
433443
for weights_manual_bit_width in weights_manual_bit_width_override:
434444
override_attr.append(weights_manual_bit_width[1])
435445
override_bitwidth.append(weights_manual_bit_width[0])
436446

447+
print("zzz", override_attr, override_bitwidth)
437448
node_qc_options_weights_list = copy.deepcopy(node_qc_options_list)
449+
print('yyy', node_qc_options_weights_list)
438450
for attr, bitwidth in zip(override_attr, override_bitwidth):
451+
print("aza", attr, bitwidth)
439452
for op_cfg in node_qc_options_list:
453+
print("bbb", op_cfg)
440454
if op_cfg in node_qc_options_weights_list:
441455
weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
456+
print('weights_attrs', weights_attrs)
442457
if attr in weights_attrs:
458+
print('chk', attr, weights_attrs)
443459
for weights_attr in weights_attrs:
444460
if attr == weights_attr and op_cfg.attr_weights_configs_mapping.get(attr).weights_n_bits != bitwidth:
445461
node_qc_options_weights_list.remove(op_cfg)
446462
else:
447463
node_qc_options_weights_list.remove(op_cfg)
448-
464+
print('node_qc_options_weights_list', node_qc_options_weights_list)
465+
print('0 base_config.attr_weights_configs_mapping', base_config.attr_weights_configs_mapping)
449466
if len(node_qc_options_weights_list) == 0:
450467
Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
451468
else:
@@ -455,19 +472,24 @@ def filter_weights_qc_options_with_manual_bit_width(
455472
updated_base_config = copy.deepcopy(base_config)
456473
for attr, bitwidth in zip(override_attr, override_bitwidth):
457474
Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bitwidth} bits.")
458-
updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr : {WEIGHTS_N_BITS: bitwidth}})
459475

476+
print('node attr, bitwidth', node, attr, bitwidth)
477+
updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr : {WEIGHTS_N_BITS: bitwidth}})
478+
print('updated_base_config', updated_base_config)
460479
if updated_base_config in node_qc_options_weights_list:
461480
# If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
462481
# point the base_config to this option.
482+
print('chk0000')
463483
base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
464484
else:
465485
# Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
486+
print('chk0001')
466487
base_config = node_qc_options_weights_list[0]
467488
if len(node_qc_options_weights_list) > 0 and not mixed_precision_enable:
468489
Logger.info(
469490
f"Request received to select weights bit-widths {weights_manual_bit_width_override}."
470491
f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
471492
f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.") # pragma: no cover
493+
print('1 base_config', base_config)
472494

473495
return base_config, node_qc_options_weights_list
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
17+
import model_compression_toolkit as mct
18+
from model_compression_toolkit.constants import PYTORCH
19+
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
20+
from model_compression_toolkit.core.common.quantization.bit_width_config import ManualBitWidthSelection, ManualWeightsBitWidthSelection
21+
from model_compression_toolkit.core import BitWidthConfig, CoreConfig
22+
23+
from model_compression_toolkit.core.common import Graph
24+
from model_compression_toolkit.core.common.graph.edge import Edge
25+
from tests_pytest._test_util.graph_builder_utils import build_node
26+
27+
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
28+
set_quantization_configuration_to_graph
29+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \
30+
FrameworkQuantizationCapabilities, OperationsSetToLayers
31+
32+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
33+
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
34+
35+
import torch
36+
from torch import nn
37+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
38+
AttachTpcToPytorch
39+
40+
from model_compression_toolkit.core import QuantizationConfig
41+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
42+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
43+
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
44+
45+
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs, generate_tpc
46+
47+
48+
#TEST_KERNEL = 'kernel'
49+
#TEST_BIAS = 'bias'
50+
51+
### Stand-in layer classes.
# Only the class identity (its type) matters to NodeTypeFilter in the
# tests below; the classes deliberately carry no behavior or state.
class Conv2D:
    """Dummy stand-in for a convolution layer type."""


class InputLayer:
    """Dummy stand-in for an input layer type."""


class Add:
    """Dummy stand-in for an elementwise-add layer type."""


class BatchNormalization:
    """Dummy stand-in for a batch-normalization layer type."""


class ReLU:
    """Dummy stand-in for a ReLU activation layer type."""


class Flatten:
    """Dummy stand-in for a flatten layer type."""


class Dense:
    """Dummy stand-in for a fully-connected layer type."""
66+
67+
#from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs
68+
69+
from tests.pytorch_tests.tpc_pytorch import get_mp_activation_pytorch_tpc_dict
70+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
71+
72+
from tests.common_tests.helpers.generate_test_tpc import generate_tpc_with_activation_mp
73+
74+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
75+
76+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS
77+
def get_tpc(kernel_n, bias_n):
    """Build a test TPC with mixed-precision kernel-weight candidates (8/4/2/16 bits).

    Args:
        kernel_n: Requested kernel weights bit-width.
            NOTE(review): currently unused — the candidate list below is
            hard-coded; either wire this parameter into the candidate
            generation or drop it. TODO confirm intended behavior.
        bias_n: Requested bias weights bit-width (also currently unused).

    Returns:
        A target platform capabilities object produced by ``generate_tpc``,
        whose mixed-precision list contains the base (8-bit) config plus
        4-, 2- and 16-bit kernel variants.
    """
    base_cfg, _, default_config = get_op_quantization_configs()

    # Build the kernel-bit-width variants via clone_and_edit so base_cfg
    # itself is left untouched.
    weights_04_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}})
    weights_02_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}})
    weights_16_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 16}})

    mx_cfg_list = [base_cfg, weights_04_bits, weights_02_bits, weights_16_bits]
    return generate_tpc(default_config, base_cfg, mx_cfg_list, 'imx500_tpc_kai')
101+
102+
103+
# AttributeQuantizationConfig(weights_quantization_method=<QuantizationMethod.SYMMETRIC: 2>, weights_n_bits=16, weights_per_channel_threshold=True, enable_weights_quantization=True, lut_values_bitwidth=None)
104+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
105+
from tests.common_tests.helpers.generate_test_tpc import generate_test_tpc
106+
107+
### test model
108+
def get_test_graph(kernel_n, bias_n):
    """Build a small dummy graph for manual weights bit-width selection tests.

    Topology::

        input -> conv1 -> add1 -> conv2 -> bn1 -> relu1 -> add2 -> flatten -> fc
                            \\______________________________/ (skip edge)

    Args:
        kernel_n: Kernel bit-width forwarded to ``get_tpc`` (currently
            ignored there — see the note in ``get_tpc``).
        bias_n: Bias bit-width forwarded to ``get_tpc`` (same caveat).

    Returns:
        A ``Graph`` with ``FrameworkQuantizationCapabilities`` and the
        default PyTorch framework info attached.
    """
    # NOTE(review): canonical_weights values are AttributeQuantizationConfig
    # objects rather than weight tensors — confirm this is what build_node
    # expects; it may only matter that the attribute keys exist.
    def _weights():
        return {KERNEL_ATTR: AttributeQuantizationConfig(weights_n_bits=8),
                BIAS_ATTR: AttributeQuantizationConfig(weights_n_bits=32)}

    n1 = build_node('input', layer_class=InputLayer)
    conv1 = build_node('conv1', layer_class=Conv2D, canonical_weights=_weights())
    add1 = build_node('add1', layer_class=Add)
    conv2 = build_node('conv2', layer_class=Conv2D, canonical_weights=_weights())
    bn1 = build_node('bn1', layer_class=BatchNormalization)
    # NOTE(review): a ReLU with kernel/bias weights is unusual — presumably
    # intentional so NodeTypeFilter(ReLU) hits a weights-carrying node; verify.
    relu = build_node('relu1', layer_class=ReLU, canonical_weights=_weights())
    add2 = build_node('add2', layer_class=Add)
    flatten = build_node('flatten', layer_class=Flatten)
    fc = build_node('fc', layer_class=Dense)

    graph = Graph('xyz', input_nodes=[n1],
                  nodes=[conv1, add1, conv2, bn1, relu, add2, flatten],
                  output_nodes=[fc],
                  edge_list=[Edge(n1, conv1, 0, 0),
                             Edge(conv1, add1, 0, 0),
                             Edge(add1, conv2, 0, 0),
                             Edge(conv2, bn1, 0, 0),
                             Edge(bn1, relu, 0, 0),
                             Edge(relu, add2, 0, 0),
                             Edge(add1, add2, 0, 0),
                             Edge(add2, flatten, 0, 0),
                             Edge(flatten, fc, 0, 0),
                             ])

    fqc = FrameworkQuantizationCapabilities(get_tpc(kernel_n, bias_n))
    graph.set_fqc(fqc)
    graph.set_fw_info(DEFAULT_PYTORCH_INFO)
    return graph
159+
160+
class TestManualWeightsBitwidthSelection:
    """Tests for manual weights bit-width selection via ``BitWidthConfig``.

    Each ``test_input_*`` is ``(filter(s), bit_width(s), weight attribute(s))``
    as accepted by ``set_manual_weights_bit_width``; each ``test_expected_*``
    is ``(filter type, layer/name, resulting bit-width)``.

    NOTE(review): ``test_expected_1`` references ``ReLU``/16 while
    ``test_input_1`` targets ``Conv2D``/8 — the expected tuples look stale.
    Confirm and correct them before tightening the assertions below.
    """

    test_input_1 = (NodeTypeFilter(Conv2D), 8, KERNEL_ATTR)
    test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16], [KERNEL_ATTR])
    test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [4, 8], [KERNEL_ATTR, BIAS_ATTR])

    test_expected_1 = (NodeTypeFilter, ReLU, 16)
    test_expected_2 = ([NodeTypeFilter, ReLU, 2], [NodeNameFilter, "conv1", 2])
    test_expected_3 = ([NodeTypeFilter, ReLU, 4], [NodeNameFilter, "conv1", 8])

    @pytest.mark.parametrize(("inputs", "expected"), [
        (test_input_1, test_expected_1),
        # TODO: enable once the multi-filter path and expected tuples are fixed.
        # (test_input_2, test_expected_2),
        # (test_input_3, test_expected_3),
    ])
    def test_manual_weights_bitwidth_selection(self, inputs, expected):
        filters, bit_widths, attrs = inputs

        # Derive the kernel/bias bit-widths implied by the inputs. `attrs`
        # may be a single attribute string or a list of attributes; in the
        # list case the bit-width sits at the matching index of `bit_widths`.
        kernel_n, bias_n = 8, 32
        if KERNEL_ATTR in attrs:
            kernel_n = bit_widths[attrs.index(KERNEL_ATTR)] if isinstance(attrs, list) else bit_widths
        if BIAS_ATTR in attrs:
            bias_n = bit_widths[attrs.index(BIAS_ATTR)] if isinstance(attrs, list) else bit_widths

        graph = get_test_graph(kernel_n, bias_n)

        core_config = CoreConfig()
        core_config.bit_width_config.set_manual_weights_bit_width(filters, bit_widths, attrs)

        updated_graph = set_quantization_configuration_to_graph(
            graph, core_config.quantization_config, core_config.bit_width_config,
            False, False)

        # Sanity check: quantization config assignment produced candidates,
        # and every surviving candidate carries a concrete weights bit-width.
        # TODO(review): tighten this to compare against `expected` once the
        # expected tuples are corrected (see class docstring).
        assert updated_graph is not None
        for node in updated_graph.nodes:
            for candidate in node.candidates_quantization_cfg:
                attr_cfg_map = candidate.weights_quantization_cfg.attributes_config_mapping
                for cfg in attr_cfg_map.values():
                    assert isinstance(cfg.weights_n_bits, int)
                    assert cfg.weights_n_bits > 0
251+

0 commit comments

Comments
 (0)