Skip to content

Commit 6345f8d

Browse files
reuvenperetz and reuvenp authored
Integrate max cut on FLNs in mixed-precision (#1427)
Integrate max cut on FLNs in mixed-precision RU computation. Change the calculation of RU to be on the original graph instead of the virtual graph in the case of BOPs computation (this is equivalent, but needed for when both BOPs and activations are restricted).
Co-authored-by: reuvenp <reuvenp@altair-semi.com>
1 parent e70c359 commit 6345f8d

File tree

9 files changed

+162
-32
lines changed

9 files changed

+162
-32
lines changed

model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ def __init__(self, model_graph: Graph):
3838
Args:
3939
model_graph: A graph representation of a model.
4040
"""
41-
42-
self.model_graph = model_graph
43-
4441
nodes = list(model_graph.nodes)
4542
memory_tensors = []
4643
node_to_tensor = []

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def __init__(self,
7878
self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
7979

8080
self.ru_targets = target_resource_utilization.get_restricted_targets()
81-
self.ru_helper = MixedPrecisionRUHelper(self.mp_graph, fw_info, fw_impl)
81+
self.ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
8282

8383
self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
8484
self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
85-
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
8685

8786
self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph,
8887
original_graph=self.original_graph)
88+
if self.using_virtual_graph:
89+
real_min_ru_config: Dict[BaseNode, int] = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(self.min_ru_config)
90+
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config)
91+
else:
92+
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
8993

9094
def search(self) -> Dict[BaseNode, int]:
9195
"""
@@ -251,7 +255,8 @@ def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]:
251255
else:
252256
cfg = self.min_ru_config.copy()
253257
cfg[node] = candidate_idx
254-
candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, cfg)
258+
real_cfg = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(cfg)
259+
candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg)
255260

256261
for target, ru in candidate_rus.items():
257262
rus_per_candidate[target].append(ru)

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from enum import Enum, auto
1919
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
2020

21+
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
22+
2123
from model_compression_toolkit.constants import FLOAT_BITWIDTH
2224
from model_compression_toolkit.core import FrameworkInfo
2325
from model_compression_toolkit.core.common import Graph, BaseNode
@@ -145,8 +147,14 @@ def cuts(self) -> Dict[Cut, List[BaseNode]]:
145147
raise RuntimeError("Failed to calculate activation memory cuts for graph.")
146148
cuts = [cut for cut in cuts if cut.mem_elements.elements]
147149
# cache cuts nodes for future use, so do not filter by target
148-
self._cuts = {cut: [self.graph.find_node_by_name(m.node_name)[0] for m in cut.mem_elements.elements]
149-
for cut in cuts}
150+
self._cuts = {
151+
cut: [
152+
node
153+
for m in cut.mem_elements.elements
154+
for node in (self.graph.fusing_info.get_fused_nodes(m.node_name) or (self.graph.find_node_by_name(m.node_name)[0],))
155+
]
156+
for cut in cuts
157+
}
150158
return self._cuts
151159

152160
def compute_resource_utilization(self,
@@ -580,7 +588,9 @@ def compute_node_bops(self,
580588

581589
def _compute_cuts(self):
582590
""" Compute activation cuts of the graph. """
583-
memory_graph = MemoryGraph(deepcopy(self.graph))
591+
# Compute memory graph on fused graph with fused nodes
592+
graph = GraphFuser().apply_node_fusion(self.graph)
593+
memory_graph = MemoryGraph(deepcopy(graph))
584594
_, _, cuts = compute_graph_max_cut(memory_graph)
585595
return cuts
586596

tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import math
1516
import typing
1617

1718
import abc
@@ -53,7 +54,8 @@ def get_base_mp_nbits_candidates():
5354
class MixedPrecisionActivationBaseTest(BaseKerasFeatureNetworkTest):
5455
def __init__(self, unit_test, activation_layers_idx, num_calibration_iter=1):
5556
super().__init__(unit_test, num_calibration_iter=num_calibration_iter)
56-
57+
# for the model that is used here, the two last tensors compose the max cut
58+
self.max_cut = 10 * 10 * 32 + 13 * 13 * 32
5759
self.activation_layers_idx = activation_layers_idx
5860

5961
def get_core_config(self):
@@ -135,14 +137,15 @@ def __init__(self, unit_test):
135137
super().__init__(unit_test, activation_layers_idx=[1, 2, 4])
136138

137139
def get_resource_utilization(self):
138-
return ResourceUtilization(weights_memory=17919, activation_memory=5407)
140+
return ResourceUtilization(weights_memory=17919, activation_memory=self.max_cut-1)
139141

140142
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
141143
# verify chosen activation bitwidth config
142144
# resource utilization is infinity -> should give best model - 8bits
143145
holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder)
144146
activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers]
145-
self.unit_test.assertTrue((activation_bits == [8, 4, 8]))
147+
# Since the max cut is the last two tensors, one of them have to get 4 bits
148+
self.unit_test.assertIn(activation_bits, ([8, 4, 8], [8, 8, 4]))
146149

147150
self.verify_quantization(quantized_model, input_x,
148151
weights_layers_idx=[2, 3],
@@ -157,7 +160,7 @@ def __init__(self, unit_test):
157160

158161
def get_resource_utilization(self):
159162
# resource utilization is for 4 bits on average
160-
return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=4300)
163+
return ResourceUtilization(weights_memory=17920 * 4 / 8, activation_memory=math.ceil(self.max_cut*4/8))
161164

162165
def get_tpc(self):
163166
eight_bits = generate_test_op_qc(**generate_test_attr_configs())
@@ -180,7 +183,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
180183
# then there is no guarantee that the activation bitwidth for each layer would be 4-bit,
181184
# this assertion tests the expected result for this specific
182185
# test with its current setup (therefore, we don't check the input layer's bitwidth)
183-
self.unit_test.assertTrue((activation_bits == [4, 8]))
186+
self.unit_test.assertTrue((activation_bits == [4, 4]))
184187

185188

186189
class MixedPrecisionActivationSearch2BitsAvgTest(MixedPrecisionActivationBaseTest):
@@ -189,7 +192,7 @@ def __init__(self, unit_test):
189192

190193
def get_resource_utilization(self):
191194
# resource utilization is for 2 bits on average
192-
return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=1544)
195+
return ResourceUtilization(weights_memory=17920.0 * 2 / 8, activation_memory=math.ceil(self.max_cut * 2 / 8))
193196

194197
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
195198
# verify chosen activation bitwidth config
@@ -213,7 +216,8 @@ def __init__(self, unit_test):
213216
super().__init__(unit_test, activation_layers_idx=[1, 3])
214217

215218
def get_resource_utilization(self):
216-
return ResourceUtilization(47, 767)
219+
# 638 = round_up((16*16*3+13*13*3)/2) -> so it must choose (4,4)
220+
return ResourceUtilization(47, 638)
217221

218222
def create_networks(self):
219223
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
@@ -225,18 +229,17 @@ def create_networks(self):
225229

226230
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
227231
# verify chosen activation bitwidth config
228-
# resource utilization is infinity -> should give best model - 8bits
229232
holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder)
230233
activation_bits = [layer.activation_holder_quantizer.get_config()['num_bits'] for layer in holder_layers]
231-
self.unit_test.assertTrue((activation_bits == [4, 8]))
234+
self.unit_test.assertTrue((activation_bits == [4, 4]))
232235

233236

234237
class MixedPrecisionActivationDepthwise4BitTest(MixedPrecisionActivationBaseTest):
235238
def __init__(self, unit_test):
236239
super().__init__(unit_test, activation_layers_idx=[1])
237240

238241
def get_resource_utilization(self):
239-
return ResourceUtilization(48.0 * 4 / 8, 768.0 * 4 / 8)
242+
return ResourceUtilization(48.0 * 4 / 8, math.ceil((16*16*3+13*13*3) * 4 / 8))
240243

241244
def get_tpc(self):
242245
eight_bits = generate_test_op_qc(**generate_test_attr_configs())
@@ -464,7 +467,7 @@ def __init__(self, unit_test):
464467

465468
def get_resource_utilization(self):
466469
# 17920: 8-bit weights, 6176: max cut of input+conv_bn
467-
return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + 6176) * 4 / 8)
470+
return ResourceUtilization(np.inf, np.inf, total_memory=(17920 + self.max_cut) * 4 / 8)
468471

469472
def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
470473
# verify chosen activation bitwidth config
@@ -485,7 +488,7 @@ def __init__(self, unit_test):
485488

486489
def get_resource_utilization(self):
487490
weights = 17920 * 4 / 8
488-
activation = 6176 * 4 / 8
491+
activation = math.ceil(self.max_cut * 4 / 8)
489492
return ResourceUtilization(weights, activation, total_memory=weights + activation)
490493

491494
def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):
@@ -514,7 +517,7 @@ def __init__(self, unit_test):
514517

515518
def get_resource_utilization(self):
516519
weights = 17920 * 4 / 8
517-
activation = 6176 * 4 / 8 # max cut of input + conv_bn
520+
activation = math.ceil(self.max_cut * 4 / 8)
518521
return ResourceUtilization(weights, activation, total_memory=(weights + activation) / 2)
519522

520523
def _compare(self, quantized_model, float_model, input_x=None, quantization_info: UserInformation = None):

tests/keras_tests/feature_networks_tests/feature_networks/qat/qat_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,13 @@ def compare(self, qat_model, finalize=False, input_x=None, quantization_info=Non
286286

287287

288288
class QATWrappersMixedPrecisionCfgTest(MixedPrecisionActivationBaseTest):
289-
def __init__(self, unit_test, ru_weights=17919, ru_activation=5407, expected_mp_cfg=[0, 4, 0, 0]):
290-
self.ru_weights = ru_weights
291-
self.ru_activation = ru_activation
292-
self.expected_mp_cfg = expected_mp_cfg
289+
def __init__(self, unit_test, ru_weights=17919, ru_activation=None, expected_mp_cfg=None):
293290
super().__init__(unit_test, activation_layers_idx=[1, 3, 6])
291+
self.ru_weights = ru_weights
292+
# The default test case is that the max cut (which is the fused conv-relu layer tensors, input and output)
293+
# must be reduced to 4 bits on average.
294+
self.ru_activation = ru_activation or (self.max_cut * 4 / 8)
295+
self.expected_mp_cfg = expected_mp_cfg or [0, 4, 0, 1] # input, conv, conv2, relu
294296

295297
def run_test(self, **kwargs):
296298
model_float = self.create_networks()

tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import abc
16+
import math
1617

1718
import typing
1819

@@ -46,6 +47,7 @@
4647
class MixedPrecisionBaseTest(BaseKerasFeatureNetworkTest):
4748
def __init__(self, unit_test, val_batch_size=1, num_calibration_iter=1):
4849
super().__init__(unit_test, val_batch_size=val_batch_size, num_calibration_iter=num_calibration_iter)
50+
self.max_cut = 10 * 10 * 32 + 13 * 13 * 32
4951

5052
def get_quantization_config(self):
5153
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
@@ -361,7 +363,7 @@ class MixedPrecisionSearchTotalMemoryNonConfNodesTest(MixedPrecisionBaseTest):
361363
def __init__(self, unit_test):
362364
super().__init__(unit_test)
363365
# Total ResourceUtilization for weights in 2 bit avg and non-configurable activation in 8 bit
364-
self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + 6176)
366+
self.target_total_ru = ResourceUtilization(total_memory=17920 * 2 / 8 + math.ceil(self.max_cut * 8 / 8))
365367

366368
def get_resource_utilization(self):
367369
return self.target_total_ru

tests/keras_tests/feature_networks_tests/test_features_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,8 @@ def test_qat(self):
831831
QuantizationAwareTrainingQuantizersTest(self).run_test()
832832
QuantizationAwareTrainingQuantizerHolderTest(self).run_test()
833833
QATWrappersMixedPrecisionCfgTest(self).run_test()
834-
QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=5408 * 4 / 8,
835-
expected_mp_cfg=[0, 5, 1, 1]).run_test()
834+
QATWrappersMixedPrecisionCfgTest(self, ru_weights=17920 * 4 / 8, ru_activation=8608 * 4 / 8,
835+
expected_mp_cfg=[0, 4, 1, 1]).run_test()
836836

837837
def test_bn_attributes_quantization(self):
838838
BNAttributesQuantization(self, quantize_linear=False).run_test()

tests/keras_tests/graph_tests/test_memory_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_memory_graph_build(self):
8484

8585
self.assertTrue(len(memory_graph.a_nodes) == 4)
8686
self.assertTrue(len(memory_graph.b_nodes) == 4)
87-
self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a)
87+
self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a])
8888
self.assertTrue(len(memory_graph.sinks_b) == 1)
8989
self.assertTrue(memory_graph.memory_lbound_single_op == 264)
9090

@@ -99,7 +99,7 @@ def test_memory_graph_node_with_multiple_outputs(self):
9999

100100
self.assertTrue(len(memory_graph.a_nodes) == 5)
101101
self.assertTrue(len(memory_graph.b_nodes) == 6)
102-
self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a)
102+
self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a])
103103
self.assertTrue(len(memory_graph.sinks_b) == 1)
104104
self.assertTrue(memory_graph.memory_lbound_single_op == 576)
105105

@@ -117,7 +117,7 @@ def test_memory_graph_with_residual(self):
117117

118118
self.assertTrue(len(memory_graph.a_nodes) == 5)
119119
self.assertTrue(len(memory_graph.b_nodes) == 5)
120-
self.assertTrue(graph.get_topo_sorted_nodes()[0] in memory_graph.sources_a)
120+
self.assertTrue(graph.get_topo_sorted_nodes()[0].name in [node.name for node in memory_graph.sources_a])
121121
self.assertTrue(len(memory_graph.sinks_b) == 1)
122122
self.assertTrue(memory_graph.memory_lbound_single_op == 199)
123123

Comments (0)