@@ -1066,9 +1066,9 @@ def test_compute_node_bops_default_qc(self, fw_impl_mock, fw_info_mock):
10661066
10671067 def test_compute_virtual_aw_node_bops_fully_quantized (self , fw_impl_mock , fw_info_mock ):
10681068 # all quantized
1069- g , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
1070- quantize_a1 = True , quantize_w1 = True ,
1071- quantize_a2 = True , quantize_w2 = True )
1069+ g , _ , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
1070+ quantize_a1 = True , quantize_w1 = True ,
1071+ quantize_a2 = True , quantize_w2 = True )
10721072 ru_calc = ResourceUtilizationCalculator (g , fw_impl_mock , fw_info_mock )
10731073
10741074 assert ru_calc .compute_node_bops (a1w2 , TIC .AnyQuantized , BM .Float ) == 42 * 32 * 32
@@ -1083,18 +1083,16 @@ def test_compute_virtual_aw_node_bops_fully_quantized(self, fw_impl_mock, fw_inf
10831083
10841084 assert ru_calc .compute_node_bops (a2 , TIC .AnyQuantized , BM .Float ) == 0
10851085
1086- assert ru_calc .compute_node_bops (w3 , TIC .AnyQuantized , BM .QMaxBit ) == 142 * 7 * 6
1087-
def test_compute_virtual_aw_node_bops_half_quantized(self, fw_impl_mock, fw_info_mock):
    """
    Check per-node BOPS of the two virtual activation-weights nodes when the graph is built
    with quantize_a1=True, quantize_w1=False, quantize_a2=False and quantize_w2=True.
    """
    graph, _, a1w2, _, a2w3, _, _ = self._build_virtual_node_graph(
        fw_impl_mock, fw_info_mock,
        quantize_a1=True, quantize_w1=False,
        quantize_a2=False, quantize_w2=True,
    )
    calculator = ResourceUtilizationCalculator(graph, fw_impl_mock, fw_info_mock)

    # (node, expected bops); the 32 factors are presumably the float bit-width of the
    # non-quantized side - confirm against the calculator implementation.
    expected_bops = (
        (a1w2, 42 * 16 * 32),
        (a2w3, 142 * 32 * 6),
    )
    for node, bops in expected_bops:
        assert calculator.compute_node_bops(node, TIC.AnyQuantized, BM.QMaxBit) == bops
10951093
10961094 def test_compute_virtual_aw_node_bops_unquantized (self , fw_impl_mock , fw_info_mock ):
1097- g , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
1095+ g , _ , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
10981096 quantize_a1 = False , quantize_w1 = False ,
10991097 quantize_a2 = False , quantize_w2 = False )
11001098 ru_calc = ResourceUtilizationCalculator (g , fw_impl_mock , fw_info_mock )
@@ -1105,7 +1103,7 @@ def test_compute_virtual_aw_node_bops_unquantized(self, fw_impl_mock, fw_info_mo
11051103 assert ru_calc .compute_node_bops (a2w3 , TIC .Any , BM .QMaxBit ) == 142 * 32 * 32
11061104
11071105 def test_compute_virtual_aw_node_bops_custom (self , fw_impl_mock , fw_info_mock ):
1108- g , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
1106+ g , _ , a1w2 , a2 , a2w3 , w3 , a3 = self ._build_virtual_node_graph (fw_impl_mock , fw_info_mock ,
11091107 quantize_a1 = False , quantize_w1 = False ,
11101108 quantize_a2 = True , quantize_w2 = True )
11111109 custom_qc_a1w2 = build_qc (5 , w_attr = {'foo' : (6 , True )})
@@ -1179,17 +1177,57 @@ class BOPNode2:
11791177 'n3' : 630 * 7 * 5 }
11801178
def test_compute_virtual_graph_resources(self, fw_impl_mock, fw_info_mock):
    """
    Compute the full resource utilization (with detailed per-target results) of the virtual
    graph built with all four quantization flags set to True, and compare the activation cuts,
    per-node weights, per-node bops and the aggregated ResourceUtilization with expected values.
    """
    graph, _, a1w2, _, a2w3, _, _ = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
                                                                   True, True, True, True)
    calculator = ResourceUtilizationCalculator(graph, fw_impl_mock, fw_info_mock)
    ru, detailed = calculator.compute_resource_utilization(TIC.Any, BM.QMaxBit, return_detailed=True)

    # Cut sizes are compared order-independently.
    expected_cuts = [24, 24 + 50 * 2, 50 * 2 + 88 * 5 / 8, 88 * 5 / 8 + 28 * 7 / 8, 28 * 7 / 8]
    actual_cuts = list(detailed[RUTarget.ACTIVATION].values())
    assert sorted(actual_cuts) == sorted(expected_cuts), detailed[RUTarget.ACTIVATION]

    expected_weights = {a1w2.name: 42 * 2, a2w3.name: 142 * 6 / 8}
    assert detailed[RUTarget.WEIGHTS] == expected_weights

    expected_bops = {a1w2.name: 42 * 16 * 16, a2w3.name: 142 * 5 * 6}
    assert detailed[RUTarget.BOPS] == expected_bops

    total_weights = 84 + 142 * 6 / 8
    expected_ru = ResourceUtilization(weights_memory=total_weights,
                                      activation_memory=155,
                                      total_memory=155 + total_weights,
                                      bops=42 * 256 + 142 * 30)
    assert ru == expected_ru
1191+
def test_virtual_graph_with_virtual_weight(self, fw_impl_mock, fw_info_mock):
    """
    Resource utilization of a graph in which the virtual split weights node stands on its own,
    i.e. it was not merged into a virtual composed activation-weights node.
    """
    # Only the nodes are reused from the helper; its graph is discarded and a new one is wired.
    _, n_in, _, _, _, w3, a3 = self._build_virtual_node_graph(fw_impl_mock, fw_info_mock,
                                                              True, True, True, True)
    edges = [Edge(n_in, w3, 0, 0), Edge(w3, a3, 0, 0)]
    graph = Graph('g', nodes=[w3], input_nodes=[n_in], output_nodes=[a3], edge_list=edges)

    calculator = ResourceUtilizationCalculator(graph, fw_impl_mock, fw_info_mock)
    ru, detailed = calculator.compute_resource_utilization(TIC.Any, BM.QMaxBit, return_detailed=True)

    expected_weights = 142 * 6 / 8
    assert list(detailed[RUTarget.WEIGHTS].values()) == [expected_weights]
    assert detailed[RUTarget.BOPS] == {}

    # The virtual weights node introduces an additional cut; all remaining cuts must still
    # come out correct.
    wa_cut = 2 * 28 * 7 / 8
    expected_cuts = [24, 24 + 28 * 7 / 8, 28 * 7 / 8, wa_cut]
    actual_cuts = list(detailed[RUTarget.ACTIVATION].values())
    assert sorted(actual_cuts) == sorted(expected_cuts)

    expected_ru = ResourceUtilization(weights_memory=expected_weights,
                                      activation_memory=wa_cut,
                                      total_memory=wa_cut + expected_weights,
                                      bops=0)
    assert ru == expected_ru
1208+
def test_multi_output_input_activation(self, fw_impl_mock, fw_info_mock):
    """ A weight node whose input activation feeds more than one node must not contribute any bops. """
    n_in = build_node('in', qcs=[build_qc()], output_shape=(None, 2, 3, 4))

    # Candidate configs of the weight node: (activation bits, 'foo' weight bits).
    n2_candidates = [
        build_qc(a_bits, w_attr={'foo': (w_bits, True)})
        for a_bits, w_bits in ((2, 16), (3, 10), (4, 7), (5, 6))
    ]
    n2 = build_node('n2', layer_class=BOPNode, output_shape=(None, 2, 44),
                    canonical_weights={'foo': np.zeros((3, 14))},
                    qcs=n2_candidates)
    n_out = build_node('out', qcs=[build_qc()], output_shape=(None, 27))

    # The input node has two outgoing edges: one to the weight node and one to the output node.
    edges = [Edge(n_in, n2, 0, 0), Edge(n_in, n_out, 0, 0)]
    g = Graph('g', input_nodes=[n_in], nodes=[n2], output_nodes=[n_out], edge_list=edges)

    def get_kernel_attr(node_type):
        # Only BOPNode has a kernel attribute; any other type yields an empty list.
        kernel_attrs = {BOPNode: ['foo']}.get(node_type)
        return kernel_attrs if kernel_attrs else []

    def get_mac_operations(n, fw_info):
        # 42 MACs for the weight node, 0 for every other node.
        return {n2: 42}.get(n, 0)

    fw_info_mock.get_kernel_op_attributes = get_kernel_attr
    fw_impl_mock.get_node_mac_operations = get_mac_operations

    ru_calc = ResourceUtilizationCalculator(g, fw_impl_mock, fw_info_mock)
    total_bops, detailed_bops = ru_calc.compute_bops(TIC.Any, BM.Float)
    assert (total_bops, detailed_bops) == (0, {})
11931231
11941232 def _build_regular_node_graph (self , enable_aq , enable_wq ):
11951233 n1 = build_node ('n1' , qcs = [build_qc (16 , enable_aq ), build_qc (7 , enable_aq )], output_shape = (None , 5 , 10 ))
@@ -1225,7 +1263,7 @@ class ActType:
12251263 build_qc (4 , quantize_a2 , w_attr = {'foo' : (7 , quantize_w1 )}),
12261264 build_qc (5 , quantize_a2 , w_attr = {'foo' : (6 , quantize_w1 )}),
12271265 ])
1228- n3 = build_node ('n3' , layer_class = BOPNode2 , output_shape = (None , 24 ),
1266+ n3 = build_node ('n3' , layer_class = BOPNode2 , output_shape = (None , 28 ),
12291267 canonical_weights = {'bar' : np .zeros ((2 , 71 ))},
12301268 qcs = [
12311269 build_qc (4 , w_attr = {'bar' : (6 , quantize_w2 )}),
@@ -1240,12 +1278,12 @@ def get_kernel_attr(node_type):
12401278 fw_impl_mock .get_node_mac_operations = lambda n , fw_info : {n2 : 42 , n3 : 142 }.get (n , 0 )
12411279
12421280 # virtual aw node made of original nodes
1243- a1w2 = VirtualActivationWeightsNode (act_node = n1 , weights_node = n2 , fw_info = fw_info_mock , ** n2 . __dict__ )
1281+ a1w2 = VirtualActivationWeightsNode (act_node = n1 , weights_node = n2 , fw_info = fw_info_mock )
12441282 a2 = VirtualSplitActivationNode (n2 , ActType , {})
12451283 w3 = VirtualSplitWeightsNode (n3 , 'bar' )
12461284 # virtual aw node made of virtual split a, w nodes
1247- a2w3 = VirtualActivationWeightsNode (act_node = a2 , weights_node = w3 , fw_info = fw_info_mock , ** n3 . __dict__ )
1285+ a2w3 = VirtualActivationWeightsNode (act_node = a2 , weights_node = w3 , fw_info = fw_info_mock )
12481286 a3 = VirtualSplitActivationNode (n3 , ActType , {})
12491287 g = Graph ('g' , nodes = [a1w2 , a2w3 , a3 ], input_nodes = [n_in ], output_nodes = [w3 ],
1250- edge_list = [Edge (n_in , a1w2 , 0 , 0 ), Edge (a1w2 , a2w3 , 0 , 0 ), Edge (a2w3 , a3 , 0 , 0 ), Edge ( a3 , w3 , 0 , 0 ) ])
1251- return g , a1w2 , a2 , a2w3 , w3 , a3
1288+ edge_list = [Edge (n_in , a1w2 , 0 , 0 ), Edge (a1w2 , a2w3 , 0 , 0 ), Edge (a2w3 , a3 , 0 , 0 )])
1289+ return g , n_in , a1w2 , a2 , a2w3 , w3 , a3
0 commit comments