@@ -6,7 +6,12 @@ class GenerateConvStreamingInstructions(OptimizerPass):
66 '''Generates the instructions for streaming implementation of CNNs'''
77
88 def match (self , node ):
9- return isinstance (node , (Conv1D , SeparableConv1D , Conv2D , SeparableConv2D ))
9+ is_match = (
10+ isinstance (node , (Conv1D , SeparableConv1D , Conv2D , SeparableConv2D ))
11+ and node .model .config .get_config_value ('IOType' ).lower () == 'io_stream'
12+ and node .get_attr ('implementation' ).lower () == 'encoded'
13+ )
14+ return is_match
1015
1116 def transform (self , model , node ):
1217 node_class = node .__class__ .__name__
@@ -18,35 +23,25 @@ def transform(self, model, node):
1823 raise Exception (f'Cannot generate instructions for node { node .name } ({ node_class } )' )
1924
2025 def _generate_1d_instructions (self , node ):
21- if node .model .config .get_config_value ('IOType' ) == 'io_stream' :
22- min_w , instructions = node .model .config .backend .compute_conv1d_instructions (
23- node .get_input_variable ().shape [0 ],
24- node .get_input_variable ().shape [1 ],
25- node .get_attr ('filt_width' ),
26- node .get_attr ('stride_width' ),
27- )
28- instructions_str = ',' .join (str (i ) for i in instructions )
29- node .set_attr ('min_width' , min_w )
30- node .set_attr ('instructions' , instructions_str )
31- else :
32- # these are unused; just put dummy values
33- node .set_attr ('min_width' , node .get_attr ('in_width' ))
34- node .set_attr ('instructions' , '0' )
26+ min_w , instructions = node .model .config .backend .compute_conv1d_instructions (
27+ node .get_input_variable ().shape [0 ],
28+ node .get_input_variable ().shape [1 ],
29+ node .get_attr ('filt_width' ),
30+ node .get_attr ('stride_width' ),
31+ )
32+ instructions_str = ',' .join (str (i ) for i in instructions )
33+ node .set_attr ('min_width' , min_w )
34+ node .set_attr ('instructions' , instructions_str )
3535
3636 def _generate_2d_instructions (self , node ):
37- if node .model .config .get_config_value ('IOType' ) == 'io_stream' :
38- min_h , min_w , instructions = node .model .config .backend .compute_conv2d_instructions (
39- node .get_input_variable ().shape [0 ],
40- node .get_input_variable ().shape [1 ],
41- node .get_input_variable ().shape [2 ],
42- node .get_attr ('filt_height' ),
43- node .get_attr ('stride_height' ),
44- )
45- instructions_str = ',' .join (str (i ) for i in instructions )
46- node .set_attr ('min_height' , min_h )
47- node .set_attr ('min_width' , min_w )
48- node .set_attr ('instructions' , instructions_str )
49- else :
50- node .set_attr ('min_height' , node .get_attr ('in_height' ))
51- node .set_attr ('min_width' , node .get_attr ('in_width' ))
52- node .set_attr ('instructions' , '0' )
37+ min_h , min_w , instructions = node .model .config .backend .compute_conv2d_instructions (
38+ node .get_input_variable ().shape [0 ],
39+ node .get_input_variable ().shape [1 ],
40+ node .get_input_variable ().shape [2 ],
41+ node .get_attr ('filt_height' ),
42+ node .get_attr ('stride_height' ),
43+ )
44+ instructions_str = ',' .join (str (i ) for i in instructions )
45+ node .set_attr ('min_height' , min_h )
46+ node .set_attr ('min_width' , min_w )
47+ node .set_attr ('instructions' , instructions_str )
0 commit comments