@@ -259,6 +259,13 @@ def _mypad(
259259 return x [:, :, :, xe ]
260260
261261
262+ def _next_power_of_two (x : int , max_val : int = 128 ) -> int :
263+ n = 1
264+ while n < x and n < max_val :
265+ n *= 2
266+ return n
267+
268+
262269def _conv2d (
263270 x : cp .ndarray ,
264271 w : np .ndarray ,
@@ -271,23 +278,16 @@ def _conv2d(
271278 co , _ , hk , wk = w .shape
272279 ho = int (np .floor (1 + (hi - hk ) / stride [0 ]))
273280 wo = int (np .floor (1 + (wi - wk ) / stride [1 ]))
274- chunk = ci // groups
275- chunko = co // groups
276281 out_shape = [b , co , ho , wo ]
277- sum_out_shape = [b , chunko , ho * stride [0 ] // stride [0 ], wo ]
278282 if mem_stack :
279- # sum_out shape is counted twice, because the size of the temporary multiplication result
280- mem_stack .malloc ((2 * np .prod (sum_out_shape ) + w .size ) * np .float32 ().itemsize )
281283 mem_stack .malloc (np .prod (out_shape ) * np .float32 ().itemsize )
282- # everything but out gets freed
283- mem_stack .free ((2 * np .prod (sum_out_shape ) + w .size ) * np .float32 ().itemsize )
284284 return out_shape
285285
286286 out = cp .zeros (out_shape , dtype = "float32" )
287287 w = cp .asarray (w )
288288 x = cp .expand_dims (x , axis = 1 )
289289 w = np .expand_dims (w , axis = 0 )
290- symbol_names = [f"double_convolution_x< { max ( hk , wk ) } >" , f"double_convolution_y< { max ( hk , wk ) } >" ]
290+ symbol_names = [f"grouped_convolution_x< { wk } >" , f"grouped_convolution_y< { hk } >" ]
291291 module = load_cuda_module ("remove_stripe_fw" , name_expressions = symbol_names )
292292 dim_x = out .shape [- 1 ]
293293 dim_y = out .shape [- 2 ]
@@ -298,21 +298,21 @@ def _conv2d(
298298 out_stride_z = out .strides [0 ] // x .dtype .itemsize
299299 out_stride_group = out .strides [1 ] // x .dtype .itemsize
300300
301- block_x = 128
301+ block_x = _next_power_of_two ( dim_x )
302302 block_dim = (block_x , 1 , 1 )
303303 grid_x = (dim_x + block_x - 1 ) // block_x
304304 grid_dim = (grid_x , dim_y , dim_z )
305305
306306 if groups == 1 :
307- double_convolution_kernel_x = module .get_function (symbol_names [0 ])
308- double_convolution_kernel_x (grid_dim , block_dim ,
307+ grouped_convolution_kernel_x = module .get_function (symbol_names [0 ])
308+ grouped_convolution_kernel_x (grid_dim , block_dim ,
309309 (dim_x , dim_y , dim_z , x , in_stride_x , in_stride_y ,
310310 in_stride_z , out , out_stride_z , out_stride_group , w ))
311311 return out
312312
313- double_convolution_kernel_y = module .get_function (symbol_names [1 ])
313+ grouped_convolution_kernel_y = module .get_function (symbol_names [1 ])
314314 in_stride_group = x .strides [2 ] // x .dtype .itemsize
315- double_convolution_kernel_y (grid_dim , block_dim ,
315+ grouped_convolution_kernel_y (grid_dim , block_dim ,
316316 (dim_x , dim_y , dim_z , x , in_stride_x , in_stride_y ,
317317 in_stride_z , in_stride_group , out , out_stride_z ,
318318 out_stride_group , w ))
@@ -334,15 +334,11 @@ def _conv_transpose2d(
334334
335335 hi = (ho - 1 ) * stride [0 ] + hk
336336 wi = (wo - 1 ) * stride [1 ] + wk
337- chunk = ci // groups
338- chunko = co // groups
339337 out_shape = [b , ci , hi , wi ]
340338 if mem_stack :
341- tmp_weighted_shape = (b , co , ho , wo )
342339 # The trouble here is that we allocate more than the returned size
343- mem_stack .malloc (np .prod (out_shape ) * np .float32 ().itemsize )
344- mem_stack .malloc ((np .prod (tmp_weighted_shape ) + w .size ) * np .float32 ().itemsize )
345- mem_stack .free ((np .prod (tmp_weighted_shape ) + w .size ) * np .float32 ().itemsize )
340+ out_actual_bytes = np .prod (out_shape ) * np .float32 ().itemsize
341+ mem_stack .malloc (out_actual_bytes )
346342 if pad != 0 :
347343 new_out_shape = [
348344 out_shape [0 ],
@@ -357,19 +353,35 @@ def _conv_transpose2d(
357353
358354 out = cp .zeros (out_shape , dtype = "float32" )
359355 w = cp .asarray (w )
360- for g in range (groups ):
361- for ii in range (hk ):
362- for jj in range (wk ):
363- x_windows = x [:, g * chunko : (g + 1 ) * chunko ]
364- out [
365- :,
366- g * chunk : (g + 1 ) * chunk ,
367- ii : ho * stride [0 ] + ii : stride [0 ],
368- jj : wo * stride [1 ] + jj : stride [1 ],
369- ] += (
370- x_windows
371- * w [g * chunko : (g + 1 ) * chunko , :, ii : ii + 1 , jj : jj + 1 ]
372- )
356+
357+ symbol_names = [f"transposed_convolution_x<{ wk } >" , f"transposed_convolution_y<{ hk } >" ]
358+ module = load_cuda_module ("remove_stripe_fw" , name_expressions = symbol_names )
359+ dim_x = out .shape [- 1 ]
360+ dim_y = out .shape [- 2 ]
361+ dim_z = out .shape [0 ]
362+ in_dim_x = x .shape [- 1 ]
363+ in_dim_y = x .shape [- 2 ]
364+ in_stride_y = x .strides [- 2 ] // x .dtype .itemsize
365+ in_stride_z = x .strides [0 ] // x .dtype .itemsize
366+
367+ block_x = _next_power_of_two (dim_x )
368+ block_dim = (block_x , 1 , 1 )
369+ grid_x = (dim_x + block_x - 1 ) // block_x
370+ grid_dim = (grid_x , dim_y , dim_z )
371+
372+ if wk > 1 :
373+ transposed_convolution_kernel_x = module .get_function (symbol_names [0 ])
374+ transposed_convolution_kernel_x (grid_dim , block_dim ,
375+ (dim_x , dim_y , dim_z , x ,
376+ in_dim_x , in_stride_y , in_stride_z , w , out ))
377+ elif hk > 1 :
378+ transposed_convolution_kernel_y = module .get_function (symbol_names [1 ])
379+ transposed_convolution_kernel_y (grid_dim , block_dim ,
380+ (dim_x , dim_y , dim_z , x ,
381+ in_dim_y , in_stride_y , in_stride_z , w , out ))
382+ else :
383+ assert (False )
384+
373385 if pad != 0 :
374386 out = out [:, :, pad [0 ] : out .shape [2 ] - pad [0 ], pad [1 ] : out .shape [3 ] - pad [1 ]]
375387 return cp .ascontiguousarray (out )
0 commit comments