@@ -38,25 +38,71 @@ class ConvConfig:
3838 output_dtype : str
3939
4040 def get_name (self ) -> str :
41- return self .OP + "_" + f"{ self .N } x{ self .H } x{ self .W } x{ self .C } x{ self .P } x{ self .Q } x{ self .F } " + "_" + f"{ self .input_dtype } x{ self .input_dtype } x{ self .output_dtype } " + "_stride" + str (self .S )
42-
41+ return (
42+ self .OP
43+ + "_"
44+ + f"{ self .N } x{ self .H } x{ self .W } x{ self .C } x{ self .P } x{ self .Q } x{ self .F } "
45+ + "_"
46+ + f"{ self .input_dtype } x{ self .input_dtype } x{ self .output_dtype } "
47+ + "_stride"
48+ + str (self .S )
49+ )
50+
4351 def get_img_shape (self ) -> str :
4452 if "nhwc" in self .OP :
4553 in_h = self .H * self .S + self .P - 1
4654 in_w = self .W * self .S + self .Q - 1
47- return str (self .N ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + str (self .C ) + "x" + self .input_dtype
55+ return (
56+ str (self .N )
57+ + "x"
58+ + str (in_h )
59+ + "x"
60+ + str (in_w )
61+ + "x"
62+ + str (self .C )
63+ + "x"
64+ + self .input_dtype
65+ )
4866 if "nchw" in self .OP :
4967 in_h = self .H * self .S + self .P - 1
5068 in_w = self .W * self .S + self .Q - 1
51- return str (self .N ) + "x" + str (self .C ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + self .input_dtype
52-
53-
69+ return (
70+ str (self .N )
71+ + "x"
72+ + str (self .C )
73+ + "x"
74+ + str (in_h )
75+ + "x"
76+ + str (in_w )
77+ + "x"
78+ + self .input_dtype
79+ )
80+
5481 def get_kernel_shape (self ) -> str :
5582 if "nhwc" in self .OP :
56- return str (self .P ) + "x" + str (self .Q ) + "x" + str (self .C ) + "x" + str (self .F ) + "x" + self .input_dtype
83+ return (
84+ str (self .P )
85+ + "x"
86+ + str (self .Q )
87+ + "x"
88+ + str (self .C )
89+ + "x"
90+ + str (self .F )
91+ + "x"
92+ + self .input_dtype
93+ )
5794 if "nchw" in self .OP :
58- return str (self .F ) + "x" + str (self .C ) + "x" + str (self .P ) + "x" + str (self .Q ) + "x" + self .input_dtype
59-
95+ return (
96+ str (self .F )
97+ + "x"
98+ + str (self .C )
99+ + "x"
100+ + str (self .P )
101+ + "x"
102+ + str (self .Q )
103+ + "x"
104+ + self .input_dtype
105+ )
60106
61107 def get_byte_count (self ) -> int :
62108 dtype_bits_map = {
@@ -80,7 +126,13 @@ def get_byte_count(self) -> int:
80126 k_height = self .P
81127 byte_count = (
82128 (batch * input_channels * in_w * in_h * bytes_per_input )
83- + (batch * output_channels * output_width * output_height * bytes_per_output )
129+ + (
130+ batch
131+ * output_channels
132+ * output_width
133+ * output_height
134+ * bytes_per_output
135+ )
84136 + (k_width * k_height * input_channels * output_channels * bytes_per_input )
85137 )
86138 return byte_count
@@ -100,6 +152,7 @@ def get_flops(self) -> int:
100152 flops = operation_per_pixel * output_pixels_per_batch * batch
101153 return flops
102154
155+
103156def generate_mlir (config : ConvConfig ):
104157 n = config .N
105158 h = config .H
@@ -116,17 +169,77 @@ def generate_mlir(config: ConvConfig):
116169 in_w = str (int (w ) * int (stride ) + int (q ) - 1 )
117170 if "nhwc" in operation :
118171 conv_type = "nhwc_hwcf"
119- lhs = str (n ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + str (c ) + "x" + str (elem_types [0 ])
120- rhs = str (p ) + "x" + str (q ) + "x" + str (c ) + "x" + str (f ) + "x" + str (elem_types [1 ])
121- out = str (n ) + "x" + str (h ) + "x" + str (w ) + "x" + str (f ) + "x" + str (elem_types [2 ])
172+ lhs = (
173+ str (n )
174+ + "x"
175+ + str (in_h )
176+ + "x"
177+ + str (in_w )
178+ + "x"
179+ + str (c )
180+ + "x"
181+ + str (elem_types [0 ])
182+ )
183+ rhs = (
184+ str (p )
185+ + "x"
186+ + str (q )
187+ + "x"
188+ + str (c )
189+ + "x"
190+ + str (f )
191+ + "x"
192+ + str (elem_types [1 ])
193+ )
194+ out = (
195+ str (n )
196+ + "x"
197+ + str (h )
198+ + "x"
199+ + str (w )
200+ + "x"
201+ + str (f )
202+ + "x"
203+ + str (elem_types [2 ])
204+ )
122205 if "nchw" in operation :
123206 conv_type = "nchw_fchw"
124- lhs = str (n ) + "x" + str (c ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + str (elem_types [0 ])
125- rhs = str (f ) + "x" + str (c ) + "x" + str (p ) + "x" + str (q ) + "x" + str (elem_types [1 ])
126- out = str (n ) + "x" + str (f ) + "x" + str (h ) + "x" + str (w ) + "x" + str (elem_types [2 ])
207+ lhs = (
208+ str (n )
209+ + "x"
210+ + str (c )
211+ + "x"
212+ + str (in_h )
213+ + "x"
214+ + str (in_w )
215+ + "x"
216+ + str (elem_types [0 ])
217+ )
218+ rhs = (
219+ str (f )
220+ + "x"
221+ + str (c )
222+ + "x"
223+ + str (p )
224+ + "x"
225+ + str (q )
226+ + "x"
227+ + str (elem_types [1 ])
228+ )
229+ out = (
230+ str (n )
231+ + "x"
232+ + str (f )
233+ + "x"
234+ + str (h )
235+ + "x"
236+ + str (w )
237+ + "x"
238+ + str (elem_types [2 ])
239+ )
127240 one = "1"
128241 zero = "0"
129- if ( elem_types [0 ][0 ] == "f" ) :
242+ if elem_types [0 ][0 ] == "f" :
130243 one = "1.0"
131244 zero = "0.0"
132245 conv_template = CONV
0 commit comments