1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15- < << << << HEAD
1615import torch
1716from torch import nn
1817from model_compression_toolkit .target_platform_capabilities .targetplatform2framework .attach2pytorch import \
@@ -69,66 +68,18 @@ def test_get_mac(minimal_tpc):
6968 fw_impl = PytorchImplementation ()
7069 fw_info = DEFAULT_PYTORCH_INFO
7170 model = Model ()
72- == == == =
73- import numpy as np
74- from keras .layers import Conv2D , Conv2DTranspose , DepthwiseConv2D , Dense , Input , Flatten
75- import keras
76-
77- from model_compression_toolkit .core import QuantizationConfig
78- from model_compression_toolkit .core .graph_prep_runner import graph_preparation_runner
79- from model_compression_toolkit .core .keras .default_framework_info import DEFAULT_KERAS_INFO
80- from model_compression_toolkit .core .keras .keras_implementation import KerasImplementation
81- from model_compression_toolkit .target_platform_capabilities .targetplatform2framework .attach2keras import \
82- AttachTpcToKeras
83-
84-
85- def data_gen ():
86- yield [np .random .randn (28 , 32 , 10 )]
87-
88-
89- def build_model ():
90- x = Input (shape = (28 , 32 , 10 ))
91- y = Conv2D (filters = 20 , kernel_size = (5 , 4 ))(x )
92- y = Conv2D (filters = 15 , kernel_size = (4 , 6 ), groups = 5 )(y )
93- y = Conv2D (filters = 8 , kernel_size = (3 , 3 ), strides = 2 )(y )
94- y = Conv2D (filters = 12 , kernel_size = (3 , 3 ), dilation_rate = 2 )(y )
95- y = Conv2DTranspose (filters = 20 , kernel_size = (5 , 3 ))(y )
96- y = Conv2DTranspose (filters = 10 , kernel_size = (3 , 3 ), strides = 2 )(y )
97- y = Conv2DTranspose (filters = 5 , kernel_size = (3 , 3 ), dilation_rate = 2 )(y )
98- y = DepthwiseConv2D (kernel_size = (2 , 3 ), depth_multiplier = 4 )(y )
99- y = DepthwiseConv2D (kernel_size = (3 , 3 ), depth_multiplier = 2 , strides = 3 )(y )
100- y = DepthwiseConv2D (kernel_size = (3 , 3 ), depth_multiplier = 2 , dilation_rate = 2 )(y )
101- y = Dense (10 )(y )
102- y = Flatten ()(y )
103- y = Dense (5 )(y )
104- return keras .Model (inputs = x , outputs = y )
105-
106-
107- def test_get_mac (minimal_tpc ):
108- fw_impl = KerasImplementation ()
109- model = build_model ()
110- fw_info = DEFAULT_KERAS_INFO
111- > >> >> >> compute bops on activation with multiple outputs
11271
11372 graph = graph_preparation_runner (model ,
11473 data_gen ,
11574 QuantizationConfig (linear_collapsing = False ),
11675 fw_info = fw_info ,
11776 fw_impl = fw_impl ,
118- << << << < HEAD
11977 fqc = AttachTpcToPytorch ().attach (minimal_tpc ),
120- == == == =
121- fqc = AttachTpcToKeras ().attach (minimal_tpc ),
122- >> >> >> > compute bops on activation with multiple outputs
12378 mixed_precision_enable = False ,
12479 running_gptq = False )
12580
12681 nodes = graph .get_topo_sorted_nodes ()
127- << << < << HEAD
12882 # assert len(nodes) == 14, nodes
129- == == == =
130- assert len (nodes ) == 14 , nodes
131- >> >> > >> compute bops on activation with multiple outputs
13283 assert fw_impl .get_node_mac_operations (nodes [0 ], fw_info ) == 0
13384 assert fw_impl .get_node_mac_operations (nodes [1 ], fw_info ) == (10 * 20 * 5 * 4 )* 24 * 29
13485 assert fw_impl .get_node_mac_operations (nodes [2 ], fw_info ) == (4 * 3 * 4 * 6 )* 5 * 21 * 24
@@ -140,17 +91,10 @@ def test_get_mac(minimal_tpc):
14091 assert fw_impl .get_node_mac_operations (nodes [8 ], fw_info ) == (5 * 2 * 3 * 4 )* 24 * 21
14192 assert fw_impl .get_node_mac_operations (nodes [9 ], fw_info ) == (10 * 3 * 3 * 4 )* 8 * 7
14293 assert fw_impl .get_node_mac_operations (nodes [10 ], fw_info ) == (40 * 3 * 3 * 2 )* 4 * 3
143- << < << << HEAD
14494 assert fw_impl .get_node_mac_operations (nodes [10 ], fw_info ) == (40 * 3 * 3 * 2 )* 4 * 3
14595 assert fw_impl .get_node_mac_operations (nodes [11 ], fw_info ) == 0
14696 assert fw_impl .get_node_mac_operations (nodes [12 ], fw_info ) == 4 * 3 * (80 * 10 )
14797 assert fw_impl .get_node_mac_operations (nodes [13 ], fw_info ) == 0
14898 assert fw_impl .get_node_mac_operations (nodes [14 ], fw_info ) == (4 * 3 * 10 )* 5
149- == == == =
150- assert fw_impl .get_node_mac_operations (nodes [11 ], fw_info ) == 4 * 3 * (80 * 10 )
151- assert fw_impl .get_node_mac_operations (nodes [12 ], fw_info ) == 0
152- assert fw_impl .get_node_mac_operations (nodes [13 ], fw_info ) == (4 * 3 * 80 * 10 )* 5
153-
154- >> >> > >> compute bops on activation with multiple outputs
15599
156100
0 commit comments