1- from hls4ml .model .layers import Conv1D , Conv2D
1+ from hls4ml .model .layers import Conv1D , Conv2D , SeparableConv1D , SeparableConv2D
22from hls4ml .model .optimizer import OptimizerPass
33from hls4ml .model .types import Source
44
@@ -7,16 +7,27 @@ class GenerateConvIm2col(OptimizerPass):
77 '''Generates code for im2col step of 1D/2D convolution'''
88
def match(self, node):
    """Return True when this pass should run on ``node``.

    A node matches when it is an instance of one of the regular or
    separable 1D/2D convolution layer classes and the model configuration
    value ``'IOType'`` equals ``'io_parallel'``.
    """
    conv_layers = (Conv1D, Conv2D, SeparableConv1D, SeparableConv2D)
    if not isinstance(node, conv_layers):
        return False
    return node.model.config.get_config_value('IOType') == 'io_parallel'
1114
def transform(self, model, node):
    """Generate the im2col line-buffer code for ``node``.

    The generator is selected from ``node.class_name``: names containing
    ``'Separable'`` use the separable generators, all others the regular
    ones, and ``'1D'`` / ``'2D'`` in the name picks the dimensionality.

    Args:
        model: Not used by this pass.
        node: The layer node to generate code for.

    Raises:
        Exception: If the class name contains neither ``'1D'`` nor ``'2D'``.
    """
    node_class = node.class_name
    separable = 'Separable' in node_class

    # The dimensionality check is the same for both families, so the
    # "unsupported node" error is raised from a single place.
    if '1D' in node_class:
        generate = self._generate_separable_im2col_1d if separable else self._generate_im2col_1d
    elif '2D' in node_class:
        generate = self._generate_separable_im2col_2d if separable else self._generate_im2col_2d
    else:
        raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')

    generate(node)
2031
2132 def _generate_im2col_1d (self , node ):
2233 code_str = node .model .config .backend .generate_conv1d_line_buffer_fn (
@@ -49,3 +60,56 @@ def _generate_im2col_2d(self, node):
4960 )
5061
5162 node .set_attr ('line_buffer_codegen' , Source (code_str ))
63+
def _generate_separable_im2col_1d(self, node):
    """Generate line-buffer code for both stages of a separable 1D convolution.

    Calls the backend's ``generate_conv1d_line_buffer_fn`` twice and stores
    each result on the node wrapped in ``Source``:

    * ``dw_line_buffer_codegen`` -- built from the input variable's shape
      and the node's ``filt_width``, ``stride_width`` and
      ``pad_left``/``pad_right`` attributes.
    * ``pw_line_buffer_codegen`` -- built with ``kernel=1`` from the output
      variable's ``shape[0]`` and the input variable's ``shape[1]``;
      stride and padding are left at the backend's defaults.

    NOTE(review): ``shape[0]``/``shape[1]`` are presumably (width, channels);
    confirm against the layer definition.
    """
    backend = node.model.config.backend
    index = str(node.get_attr('index'))
    n_partitions = node.get_attr('n_partitions')
    in_shape = node.get_input_variable().shape

    dw_code_str = backend.generate_conv1d_line_buffer_fn(
        index + '_dw',
        n_partitions,
        in_shape[0],
        in_shape[1],
        kernel=node.get_attr('filt_width'),
        stride=node.get_attr('stride_width'),
        pad=(node.get_attr('pad_left'), node.get_attr('pad_right')),
    )
    node.set_attr('dw_line_buffer_codegen', Source(dw_code_str))

    pw_code_str = backend.generate_conv1d_line_buffer_fn(
        index + '_pw',
        n_partitions,
        node.get_output_variable().shape[0],
        in_shape[1],
        kernel=1,
    )
    node.set_attr('pw_line_buffer_codegen', Source(pw_code_str))
86+
def _generate_separable_im2col_2d(self, node):
    """Generate line-buffer code for both stages of a separable 2D convolution.

    Calls the backend's ``generate_conv2d_line_buffer_fn`` twice and stores
    each result on the node wrapped in ``Source``:

    * ``dw_line_buffer_codegen`` -- built from the input variable's shape
      and the node's filter size, stride and four-sided padding attributes.
    * ``pw_line_buffer_codegen`` -- built with ``kernel=(1, 1)`` from the
      output variable's ``shape[0]``/``shape[1]`` and the input variable's
      ``shape[2]``; stride and padding are left at the backend's defaults.

    NOTE(review): the shape axes are presumably (height, width, channels);
    confirm against the layer definition.
    """
    backend = node.model.config.backend
    index = str(node.get_attr('index'))
    n_partitions = node.get_attr('n_partitions')
    in_shape = node.get_input_variable().shape

    dw_code_str = backend.generate_conv2d_line_buffer_fn(
        index + '_dw',
        n_partitions,
        in_shape[0],
        in_shape[1],
        in_shape[2],
        kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')),
        stride=(node.get_attr('stride_height'), node.get_attr('stride_width')),
        pad=(
            node.get_attr('pad_top'),
            node.get_attr('pad_bottom'),
            node.get_attr('pad_left'),
            node.get_attr('pad_right'),
        ),
    )
    node.set_attr('dw_line_buffer_codegen', Source(dw_code_str))

    out_shape = node.get_output_variable().shape
    pw_code_str = backend.generate_conv2d_line_buffer_fn(
        index + '_pw',
        n_partitions,
        out_shape[0],
        out_shape[1],
        in_shape[2],
        kernel=(1, 1),
    )
    node.set_attr('pw_line_buffer_codegen', Source(pw_code_str))
0 commit comments