Skip to content

Commit d10483b

Browse files
irenab
authored and committed
dont compute bops for virtual weights node
1 parent 616631c commit d10483b

File tree

4 files changed

+110
-43
lines changed

4 files changed

+110
-43
lines changed

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
RUTarget, ResourceUtilization
3333
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
3434
NodeActivationQuantizationConfig
35+
from model_compression_toolkit.core.common.substitutions.virtual_activation_weights_composition import \
36+
BaseVirtualActivationWeightsComposition, get_input_activation_if_composable
3537

3638

3739
class BitwidthMode(Enum):
@@ -510,13 +512,19 @@ def compute_node_bops(self,
510512
if w_qc and bitwidth_mode != BitwidthMode.QCustom:
511513
raise ValueError(self.unexpected_qc_error)
512514

513-
# extract the original weight node for mac computation
515+
if isinstance(n, VirtualSplitWeightsNode):
516+
# Virtual weights node can only be present if it couldn't be merged into VirtualActivationWeightsNode.
517+
# This means that during MP search we cannot compute bops for all A/W nbits combinations. To prevent
518+
# inconsistencies we ignore such nodes for bops computation.
519+
return 0
520+
521+
# Fetch the original weights node for mac computation (VirtualActivationWeightsNode input/output shapes are
522+
# based on the activation original node, not weights original node)
514523
orig_w_node = n
515524
if isinstance(n, VirtualActivationWeightsNode):
516525
orig_w_node = n.original_weights_node
517-
518-
if isinstance(orig_w_node, VirtualSplitWeightsNode):
519-
orig_w_node = orig_w_node.origin_node
526+
if isinstance(orig_w_node, VirtualSplitWeightsNode):
527+
orig_w_node = orig_w_node.origin_node
520528

521529
# check if the node has kernel
522530
kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
@@ -535,10 +543,9 @@ def compute_node_bops(self,
535543
# we don't need the original node (and cannot use it for custom configuration anyway)
536544
a_node = n
537545
else:
538-
incoming_edges = self.graph.incoming_edges(n)
539-
assert len(incoming_edges) == 1, \
540-
f'Unexpected number of inputs {len(incoming_edges)} for BOPS calculation. Expected 1.'
541-
a_node = incoming_edges[0].source_node
546+
a_node = get_input_activation_if_composable(self.graph, n, warn=False)
547+
if a_node is None:
548+
return 0
542549

543550
if (target_criterion == TargetInclusionCriterion.AnyQuantized and
544551
not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):

model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
from typing import Optional
1516

1617
from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
1718
from model_compression_toolkit.logger import Logger
@@ -44,14 +45,8 @@ def substitute(self,
4445
raise TypeError(f'Matched node {weights_node} was expected to be of type VirtualSplitWeightsNode. '
4546
f'This substitution is expected to be called after activation-weights split.')
4647

47-
predecessors = graph.get_prev_nodes(weights_node)
48-
assert len(predecessors) == 1, (f'Matched node for {self.__class__.__name__} substitution is expected to have'
49-
f'exactly one input, node {weights_node} has {len(predecessors)}')
50-
act_node = predecessors[0]
51-
if len(graph.out_edges(act_node)) > 1:
52-
Logger.warning(f"Node {act_node.name} has multiple outgoing edges, which is not supported with "
53-
f"mixed-precision search under bit-operations constraint. In such case, it might result in "
54-
f"incorrect resource utilization computation and suboptimal bits selection.")
48+
act_node = get_input_activation_if_composable(graph, weights_node, warn=True)
49+
if act_node is None:
5550
return graph
5651

5752
# Virtual composed activation-weights node
@@ -70,3 +65,29 @@ def substitute(self,
7065
graph.remove_node(act_node)
7166

7267
return graph
68+
69+
70+
def get_input_activation_if_composable(graph: Graph, weights_node: BaseNode, warn: bool) -> Optional[BaseNode]:
71+
"""
72+
Get input activation node for composition, or None if not composable.
73+
74+
Args:
75+
graph: graph.
76+
weights_node: weights node for composition.
77+
warn: whether to log a warning if not composable.
78+
79+
Returns:
80+
Input activation node or None.
81+
"""
82+
predecessors = graph.get_prev_nodes(weights_node)
83+
assert len(predecessors) == 1, (f'Weights node is expected to have exactly one input, '
84+
f'node {weights_node} has {len(predecessors)}')
85+
act_node = predecessors[0]
86+
if len(graph.out_edges(act_node)) > 1:
87+
if warn:
88+
Logger.warning(f"Node {act_node.name} has multiple outgoing edges, which is not supported with "
89+
f"mixed-precision search under bit-operations constraint. In such case, it might result in "
90+
f"incorrect resource utilization computation and suboptimal bits selection.")
91+
return None
92+
93+
return act_node

tests_pytest/common/core/common/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,9 +1066,9 @@ def test_compute_node_bops_default_qc(self, fw_impl_mock, fw_info_mock):
10661066

10671067
def test_compute_virtual_aw_node_bops_fully_quantized(self, fw_impl_mock, fw_info_mock):
10681068
# all quantized
1069-
g, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
1070-
quantize_a1=True, quantize_w1=True,
1071-
quantize_a2=True, quantize_w2=True)
1069+
g, _, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
1070+
quantize_a1=True, quantize_w1=True,
1071+
quantize_a2=True, quantize_w2=True)
10721072
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
10731073

10741074
assert ru_calc.compute_node_bops(a1w2, TIC.AnyQuantized, BM.Float) == 42 * 32 * 32
@@ -1083,18 +1083,16 @@ def test_compute_virtual_aw_node_bops_fully_quantized(self, fw_impl_mock, fw_inf
10831083

10841084
assert ru_calc.compute_node_bops(a2, TIC.AnyQuantized, BM.Float) == 0
10851085

1086-
assert ru_calc.compute_node_bops(w3, TIC.AnyQuantized, BM.QMaxBit) == 142 * 7 * 6
1087-
10881086
def test_compute_virtual_aw_node_bops_half_quantized(self, fw_impl_mock, fw_info_mock):
1089-
g, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
1087+
g, _, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
10901088
quantize_a1=True, quantize_w1=False,
10911089
quantize_a2=False, quantize_w2=True)
10921090
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
10931091
assert ru_calc.compute_node_bops(a1w2, TIC.AnyQuantized, BM.QMaxBit) == 42 * 16 * 32
10941092
assert ru_calc.compute_node_bops(a2w3, TIC.AnyQuantized, BM.QMaxBit) == 142 * 32 * 6
10951093

10961094
def test_compute_virtual_aw_node_bops_unquantized(self, fw_impl_mock, fw_info_mock):
1097-
g, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
1095+
g, _, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
10981096
quantize_a1=False, quantize_w1=False,
10991097
quantize_a2=False, quantize_w2=False)
11001098
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
@@ -1105,7 +1103,7 @@ def test_compute_virtual_aw_node_bops_unquantized(self, fw_impl_mock, fw_info_mo
11051103
assert ru_calc.compute_node_bops(a2w3, TIC.Any, BM.QMaxBit) == 142 * 32 * 32
11061104

11071105
def test_compute_virtual_aw_node_bops_custom(self, fw_impl_mock, fw_info_mock):
1108-
g, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
1106+
g, _, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
11091107
quantize_a1=False, quantize_w1=False,
11101108
quantize_a2=True, quantize_w2=True)
11111109
custom_qc_a1w2 = build_qc(5, w_attr={'foo': (6, True)})
@@ -1179,17 +1177,57 @@ class BOPNode2:
11791177
'n3': 630 * 7 * 5}
11801178

11811179
def test_compute_virtual_graph_resources(self, fw_impl_mock, fw_info_mock):
1182-
g, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock, True, True, True, True)
1180+
g, _, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock, True, True, True, True)
11831181
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
11841182
ru, detailed = ru_calc.compute_resource_utilization(TIC.Any, BM.QMaxBit, return_detailed=True)
11851183
assert (sorted(list(detailed[RUTarget.ACTIVATION].values())) ==
1186-
sorted([24, 24 + 50*2, 50*2+88*5/8, 88*5/8 + 24*7/8, 24*7/8, 0])), detailed[RUTarget.ACTIVATION]
1187-
assert detailed[RUTarget.WEIGHTS] == {a1w2.name: 42*2, a2w3.name: 142*6/8, w3.name: 142*6/8}
1188-
assert detailed[RUTarget.BOPS] == {a1w2.name: 42*16*16, a2w3.name: 142*5*6, w3.name: 142*7*6}
1189-
assert ru == ResourceUtilization(weights_memory=84 + 142*1.5,
1184+
sorted([24, 24 + 50*2, 50*2+88*5/8, 88*5/8 + 28*7/8, 28*7/8])), detailed[RUTarget.ACTIVATION]
1185+
assert detailed[RUTarget.WEIGHTS] == {a1w2.name: 42*2, a2w3.name: 142*6/8}
1186+
assert detailed[RUTarget.BOPS] == {a1w2.name: 42*16*16, a2w3.name: 142*5*6}
1187+
assert ru == ResourceUtilization(weights_memory=84 + 142*6/8,
11901188
activation_memory=155,
1191-
total_memory=155+297,
1192-
bops=42*256+142*30+142*42)
1189+
total_memory=155 + (84 + 142*6/8),
1190+
bops=42*256+142*30)
1191+
1192+
def test_virtual_graph_with_virtual_weight(self, fw_impl_mock, fw_info_mock):
1193+
# virtual weight node wasn't merged into virtual composed node
1194+
_, n_in, a1w2, a2, a2w3, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock, True, True, True, True)
1195+
g = Graph('g', nodes=[w3], input_nodes=[n_in], output_nodes=[a3],
1196+
edge_list=[Edge(n_in, w3, 0, 0), Edge(w3, a3, 0, 0)])
1197+
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
1198+
ru, detailed = ru_calc.compute_resource_utilization(TIC.Any, BM.QMaxBit, return_detailed=True)
1199+
assert list(detailed[RUTarget.WEIGHTS].values()) == [142 * 6 / 8]
1200+
assert detailed[RUTarget.BOPS] == {}
1201+
# the extra cut that is created by virtual weight node. The rest of the cuts must be correct.
1202+
wa_cut = 2*28*7/8
1203+
assert sorted(list(detailed[RUTarget.ACTIVATION].values())) == sorted([24, 24+28*7/8, 28*7/8, wa_cut])
1204+
assert ru == ResourceUtilization(weights_memory=142 * 6 / 8,
1205+
activation_memory=wa_cut,
1206+
total_memory=wa_cut + (142 * 6 / 8),
1207+
bops=0)
1208+
1209+
def test_multi_output_input_activation(self, fw_impl_mock, fw_info_mock):
1210+
""" No bops should be calculated for weight node if its input activation has multiple outputs. """
1211+
n_in = build_node('in', qcs=[build_qc()], output_shape=(None, 2, 3, 4))
1212+
n2 = build_node('n2', layer_class=BOPNode, output_shape=(None, 2, 44),
1213+
canonical_weights={'foo': np.zeros((3, 14))},
1214+
qcs=[
1215+
build_qc(2, w_attr={'foo': (16, True)}),
1216+
build_qc(3, w_attr={'foo': (10, True)}),
1217+
build_qc(4, w_attr={'foo': (7, True)}),
1218+
build_qc(5, w_attr={'foo': (6, True)}),
1219+
])
1220+
n_out = build_node('out', qcs=[build_qc()], output_shape=(None, 27))
1221+
g = Graph('g', input_nodes=[n_in], nodes=[n2], output_nodes=[n_out],
1222+
edge_list=[Edge(n_in, n2, 0, 0), Edge(n_in, n_out, 0, 0)])
1223+
1224+
def get_kernel_attr(node_type):
1225+
return {BOPNode: ['foo']}.get(node_type) or []
1226+
fw_info_mock.get_kernel_op_attributes = get_kernel_attr
1227+
fw_impl_mock.get_node_mac_operations = lambda n, fw_info: {n2: 42}.get(n, 0)
1228+
1229+
ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
1230+
assert ru_calc.compute_bops(TIC.Any, BM.Float) == (0, {})
11931231

11941232
def _build_regular_node_graph(self, enable_aq, enable_wq):
11951233
n1 = build_node('n1', qcs=[build_qc(16, enable_aq), build_qc(7, enable_aq)], output_shape=(None, 5, 10))
@@ -1225,7 +1263,7 @@ class ActType:
12251263
build_qc(4, quantize_a2, w_attr={'foo': (7, quantize_w1)}),
12261264
build_qc(5, quantize_a2, w_attr={'foo': (6, quantize_w1)}),
12271265
])
1228-
n3 = build_node('n3', layer_class=BOPNode2, output_shape=(None, 24),
1266+
n3 = build_node('n3', layer_class=BOPNode2, output_shape=(None, 28),
12291267
canonical_weights={'bar': np.zeros((2, 71))},
12301268
qcs=[
12311269
build_qc(4, w_attr={'bar': (6, quantize_w2)}),
@@ -1240,12 +1278,12 @@ def get_kernel_attr(node_type):
12401278
fw_impl_mock.get_node_mac_operations = lambda n, fw_info: {n2: 42, n3: 142}.get(n, 0)
12411279

12421280
# virtual aw node made of original nodes
1243-
a1w2 = VirtualActivationWeightsNode(act_node=n1, weights_node=n2, fw_info=fw_info_mock, **n2.__dict__)
1281+
a1w2 = VirtualActivationWeightsNode(act_node=n1, weights_node=n2, fw_info=fw_info_mock)
12441282
a2 = VirtualSplitActivationNode(n2, ActType, {})
12451283
w3 = VirtualSplitWeightsNode(n3, 'bar')
12461284
# virtual aw node made of virtual split a, w nodes
1247-
a2w3 = VirtualActivationWeightsNode(act_node=a2, weights_node=w3, fw_info=fw_info_mock, **n3.__dict__)
1285+
a2w3 = VirtualActivationWeightsNode(act_node=a2, weights_node=w3, fw_info=fw_info_mock)
12481286
a3 = VirtualSplitActivationNode(n3, ActType, {})
12491287
g = Graph('g', nodes=[a1w2, a2w3, a3], input_nodes=[n_in], output_nodes=[w3],
1250-
edge_list=[Edge(n_in, a1w2, 0, 0), Edge(a1w2, a2w3, 0, 0), Edge(a2w3, a3, 0, 0), Edge(a3, w3, 0, 0)])
1251-
return g, a1w2, a2, a2w3, w3, a3
1288+
edge_list=[Edge(n_in, a1w2, 0, 0), Edge(a1w2, a2w3, 0, 0), Edge(a2w3, a3, 0, 0)])
1289+
return g, n_in, a1w2, a2, a2w3, w3, a3

tests_pytest/keras/core/mixed_precision/test_resource_utilization.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def data_gen():
109109

110110
class TestRUIntegration:
111111
def test_orig_vs_virtual_sequential_graph(self):
112+
""" Test detailed ru computation on original and corresponding virtual graph. """
112113
inputs = Input(shape=(18, 18, 3))
113114
x = Conv2D(filters=8, kernel_size=5)(inputs)
114115
x = tf.add(x, np.ones((14, 8))) # => activation with const in the composed node
@@ -175,6 +176,8 @@ def test_orig_vs_virtual_sequential_graph(self):
175176
assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops
176177

177178
def test_mult_output_activation(self):
179+
""" Tests the case when input activation has multiple outputs -> virtual weights nodes are not merged
180+
into VirtualActivationWeightsNode. """
178181
inputs = Input(shape=(16, 16, 3))
179182
x1 = Conv2D(filters=15, kernel_size=3, groups=3)(inputs)
180183
x2 = DepthwiseConv2D(kernel_size=3, depth_multiplier=5)(inputs)
@@ -197,20 +200,19 @@ def test_mult_output_activation(self):
197200
(14 * 14 * 15 * binary_out_a_bit + 14 * 14 * 10 * linear_a_min_nbit) / 8,
198201
14 * 14 * 10 * linear_a_min_nbit / 8]
199202

200-
# the order of conv and dwconv is not guaranteed but they have same values
203+
# the order of conv and dwconv is not guaranteed, but they have same values
201204
exp_w_ru = [3*3*1*15*linear_w_min_nbit/8,
202205
3*3*3*5*linear_w_min_nbit/8,
203206
15 * 10 * linear_w_min_nbit/8]
204-
exp_bops = [(3*3*1*15)*(14*14)*default_a_nbit*linear_w_min_nbit,
205-
(3*3*3*5)*(14*14)*default_a_nbit*linear_w_min_nbit,
206-
(15*10)*(14*14)*binary_out_a_bit*linear_w_min_nbit]
207+
# bops are not computed for virtual weights nodes
208+
exp_bops = [(15*10)*(14*14)*binary_out_a_bit*linear_w_min_nbit]
207209

208210
assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
209211
assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru
210212
assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops
211213

212214
virtual_graph = substitute(copy.deepcopy(graph),
213-
fw_impl.get_substitutions_virtual_weights_activation_coupling())
215+
self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
214216
assert len(virtual_graph.nodes) == 7
215217
assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1
216218
assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3
@@ -222,8 +224,7 @@ def test_mult_output_activation(self):
222224
return_detailed=True)
223225
assert ru_virtual == ru_orig
224226
# conv and dwconv each remain as a pair of virtual W and virtual A nodes. Remaining virtual W nodes mess up the
225-
# cuts. However, this should only add virtualW-virtualA cuts, all cuts from the original graph should be
226-
# identical
227+
# cuts - but this should only add virtualW-virtualA cuts, all cuts from the original graph should stay identical
227228
assert not set(exp_cuts_ru) - set(detailed_virtual[RUTarget.ACTIVATION].values())
228229
assert self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) == exp_w_ru
229230
assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

0 commit comments

Comments
 (0)