Skip to content

Commit 3d9d1f1

Browse files
committed
Custom kernels in _conv_transpose2d
1 parent bebfbe0 commit 3d9d1f1

File tree

2 files changed

+121
-36
lines changed

2 files changed

+121
-36
lines changed

httomolibgpu/cuda_kernels/remove_stripe_fw.cu

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
template<int WSize>
2-
__global__ void double_convolution_x(
2+
__global__ void grouped_convolution_x(
33
int dim_x,
44
int dim_y,
55
int dim_z,
@@ -32,12 +32,12 @@ __global__ void double_convolution_x(
3232
acc += w[w_idx] * in[in_idx];
3333
}
3434
const int out_idx = g_thd_x + g_thd_y * dim_x + g_thd_z * out_stride_z + i * out_stride_group;
35-
out[out_idx] += acc;
35+
out[out_idx] = acc;
3636
}
3737
}
3838

3939
template<int WSize>
40-
__global__ void double_convolution_y(
40+
__global__ void grouped_convolution_y(
4141
int dim_x,
4242
int dim_y,
4343
int dim_z,
@@ -75,8 +75,81 @@ __global__ void double_convolution_y(
7575
acc += w[w_idx] * in[in_idx];
7676
}
7777
const int out_idx = g_thd_x + g_thd_y * dim_x + g_thd_z * out_stride_z + (out_groups * group + i) * out_stride_group;
78-
out[out_idx] += acc;
78+
out[out_idx] = acc;
7979
}
8080
}
8181
}
8282

83+
// Transposed (fractionally-strided) 1D convolution along x with an implicit
// upsampling stride of 2 (see `item_out_stride` below).
//
// Launch: one thread per output element; the grid must cover
// (dim_x, dim_y, dim_z) and threads outside that box exit early, so any
// block shape and tail overshoot are valid.
//
// Preconditions (assumed from the host-side launcher — confirm at call site):
//   - `in` is addressed as in_x + y * in_stride_y + z * in_stride_z
//     (element strides, not bytes), with 0 <= in_x < in_dim_x;
//   - `w` holds WSize filter taps;
//   - `out` is contiguous with shape (dim_z, dim_y, dim_x).
template<int WSize>
__global__ void transposed_convolution_x(
    int dim_x,                     // output width
    int dim_y,                     // output height
    int dim_z,                     // output depth (batch/slice count)
    const float* __restrict__ in,  // read-only input volume
    int in_dim_x,                  // input width
    int in_stride_y,               // input row stride, in elements
    int in_stride_z,               // input slice stride, in elements
    const float* __restrict__ w,   // read-only filter taps (WSize entries)
    float* __restrict__ out        // output volume, contiguous
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    // Guard the grid tail: the launch grid may overshoot the output box.
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    // Upsampling factor of the transposed convolution along x.
    constexpr int item_out_stride = 2;
    // Row/slice base offset is loop-invariant: hoist it out of the tap loop.
    const int in_base = g_thd_y * in_stride_y + g_thd_z * in_stride_z;
    float acc = 0.F;
#pragma unroll
    for (int i = 0; i < WSize; ++i)
    {
        // Only input samples that land exactly on this output position —
        // i.e. (g_thd_x - i) divisible by the upsampling stride — contribute.
        const int in_x = (g_thd_x - i) / item_out_stride;
        const int in_x_mod = (g_thd_x - i) % item_out_stride;
        // For negative (g_thd_x - i), C++ truncating '/' and '%' yield either
        // a non-zero remainder or a negative in_x, so both cases are rejected
        // by this guard.
        if (in_x_mod == 0 && in_x >= 0 && in_x < in_dim_x)
        {
            acc += in[in_x + in_base] * w[i];
        }
    }
    const int out_idx = g_thd_x + dim_x * g_thd_y + dim_x * dim_y * g_thd_z;
    out[out_idx] = acc;
}
119+
120+
// Transposed (fractionally-strided) 1D convolution along y with an implicit
// upsampling stride of 2 (see `item_out_stride` below).
//
// Launch: one thread per output element; the grid must cover
// (dim_x, dim_y, dim_z) and threads outside that box exit early, so any
// block shape and tail overshoot are valid.
//
// Preconditions (assumed from the host-side launcher — confirm at call site):
//   - `in` is addressed as x + in_y * in_stride_y + z * in_stride_z
//     (element strides, not bytes), with 0 <= in_y < in_dim_y;
//   - `w` holds WSize filter taps;
//   - `out` is contiguous with shape (dim_z, dim_y, dim_x).
template<int WSize>
__global__ void transposed_convolution_y(
    int dim_x,                     // output width
    int dim_y,                     // output height
    int dim_z,                     // output depth (batch/slice count)
    const float* __restrict__ in,  // read-only input volume
    int in_dim_y,                  // input height
    int in_stride_y,               // input row stride, in elements
    int in_stride_z,               // input slice stride, in elements
    const float* __restrict__ w,   // read-only filter taps (WSize entries)
    float* __restrict__ out        // output volume, contiguous
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    // Guard the grid tail: the launch grid may overshoot the output box.
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    // Upsampling factor of the transposed convolution along y.
    constexpr int item_out_stride = 2;
    // Column/slice base offset is loop-invariant: hoist it out of the tap loop.
    const int in_base = g_thd_x + g_thd_z * in_stride_z;
    float acc = 0.F;
#pragma unroll
    for (int i = 0; i < WSize; ++i)
    {
        // Only input rows that land exactly on this output row —
        // i.e. (g_thd_y - i) divisible by the upsampling stride — contribute.
        const int in_y = (g_thd_y - i) / item_out_stride;
        const int in_y_mod = (g_thd_y - i) % item_out_stride;
        // For negative (g_thd_y - i), C++ truncating '/' and '%' yield either
        // a non-zero remainder or a negative in_y, so both cases are rejected
        // by this guard.
        if (in_y_mod == 0 && in_y >= 0 && in_y < in_dim_y)
        {
            acc += in[in_base + in_y * in_stride_y] * w[i];
        }
    }
    const int out_idx = g_thd_x + dim_x * g_thd_y + dim_x * dim_y * g_thd_z;
    out[out_idx] = acc;
}

httomolibgpu/prep/stripe.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ def _mypad(
259259
return x[:, :, :, xe]
260260

261261

262+
def _next_power_of_two(x: int, max_val: int = 128) -> int:
263+
n = 1
264+
while n < x and n < max_val:
265+
n *= 2
266+
return n
267+
268+
262269
def _conv2d(
263270
x: cp.ndarray,
264271
w: np.ndarray,
@@ -271,23 +278,16 @@ def _conv2d(
271278
co, _, hk, wk = w.shape
272279
ho = int(np.floor(1 + (hi - hk) / stride[0]))
273280
wo = int(np.floor(1 + (wi - wk) / stride[1]))
274-
chunk = ci // groups
275-
chunko = co // groups
276281
out_shape = [b, co, ho, wo]
277-
sum_out_shape = [b, chunko, ho * stride[0] // stride[0], wo]
278282
if mem_stack:
279-
# sum_out shape is counted twice, because the size of the temporary multiplication result
280-
mem_stack.malloc((2 * np.prod(sum_out_shape) + w.size) * np.float32().itemsize)
281283
mem_stack.malloc(np.prod(out_shape) * np.float32().itemsize)
282-
# everything but out gets freed
283-
mem_stack.free((2 * np.prod(sum_out_shape) + w.size) * np.float32().itemsize)
284284
return out_shape
285285

286286
out = cp.zeros(out_shape, dtype="float32")
287287
w = cp.asarray(w)
288288
x = cp.expand_dims(x, axis=1)
289289
w = np.expand_dims(w, axis=0)
290-
symbol_names = [f"double_convolution_x<{max(hk, wk)}>", f"double_convolution_y<{max(hk, wk)}>"]
290+
symbol_names = [f"grouped_convolution_x<{wk}>", f"grouped_convolution_y<{hk}>"]
291291
module = load_cuda_module("remove_stripe_fw", name_expressions=symbol_names)
292292
dim_x = out.shape[-1]
293293
dim_y = out.shape[-2]
@@ -298,21 +298,21 @@ def _conv2d(
298298
out_stride_z = out.strides[0] // x.dtype.itemsize
299299
out_stride_group = out.strides[1] // x.dtype.itemsize
300300

301-
block_x = 128
301+
block_x = _next_power_of_two(dim_x)
302302
block_dim = (block_x, 1, 1)
303303
grid_x = (dim_x + block_x - 1) // block_x
304304
grid_dim = (grid_x, dim_y, dim_z)
305305

306306
if groups == 1:
307-
double_convolution_kernel_x = module.get_function(symbol_names[0])
308-
double_convolution_kernel_x(grid_dim, block_dim,
307+
grouped_convolution_kernel_x = module.get_function(symbol_names[0])
308+
grouped_convolution_kernel_x(grid_dim, block_dim,
309309
(dim_x, dim_y, dim_z, x, in_stride_x, in_stride_y,
310310
in_stride_z, out, out_stride_z, out_stride_group, w))
311311
return out
312312

313-
double_convolution_kernel_y = module.get_function(symbol_names[1])
313+
grouped_convolution_kernel_y = module.get_function(symbol_names[1])
314314
in_stride_group = x.strides[2] // x.dtype.itemsize
315-
double_convolution_kernel_y(grid_dim, block_dim,
315+
grouped_convolution_kernel_y(grid_dim, block_dim,
316316
(dim_x, dim_y, dim_z, x, in_stride_x, in_stride_y,
317317
in_stride_z, in_stride_group, out, out_stride_z,
318318
out_stride_group, w))
@@ -334,15 +334,11 @@ def _conv_transpose2d(
334334

335335
hi = (ho - 1) * stride[0] + hk
336336
wi = (wo - 1) * stride[1] + wk
337-
chunk = ci // groups
338-
chunko = co // groups
339337
out_shape = [b, ci, hi, wi]
340338
if mem_stack:
341-
tmp_weighted_shape = (b, co, ho, wo)
342339
# The trouble here is that we allocate more than the returned size
343-
mem_stack.malloc(np.prod(out_shape) * np.float32().itemsize)
344-
mem_stack.malloc((np.prod(tmp_weighted_shape) + w.size) * np.float32().itemsize)
345-
mem_stack.free((np.prod(tmp_weighted_shape) + w.size) * np.float32().itemsize)
340+
out_actual_bytes = np.prod(out_shape) * np.float32().itemsize
341+
mem_stack.malloc(out_actual_bytes)
346342
if pad != 0:
347343
new_out_shape = [
348344
out_shape[0],
@@ -357,19 +353,35 @@ def _conv_transpose2d(
357353

358354
out = cp.zeros(out_shape, dtype="float32")
359355
w = cp.asarray(w)
360-
for g in range(groups):
361-
for ii in range(hk):
362-
for jj in range(wk):
363-
x_windows = x[:, g * chunko : (g + 1) * chunko]
364-
out[
365-
:,
366-
g * chunk : (g + 1) * chunk,
367-
ii : ho * stride[0] + ii : stride[0],
368-
jj : wo * stride[1] + jj : stride[1],
369-
] += (
370-
x_windows
371-
* w[g * chunko : (g + 1) * chunko, :, ii : ii + 1, jj : jj + 1]
372-
)
356+
357+
symbol_names = [f"transposed_convolution_x<{wk}>", f"transposed_convolution_y<{hk}>"]
358+
module = load_cuda_module("remove_stripe_fw", name_expressions=symbol_names)
359+
dim_x = out.shape[-1]
360+
dim_y = out.shape[-2]
361+
dim_z = out.shape[0]
362+
in_dim_x = x.shape[-1]
363+
in_dim_y = x.shape[-2]
364+
in_stride_y = x.strides[-2] // x.dtype.itemsize
365+
in_stride_z = x.strides[0] // x.dtype.itemsize
366+
367+
block_x = _next_power_of_two(dim_x)
368+
block_dim = (block_x, 1, 1)
369+
grid_x = (dim_x + block_x - 1) // block_x
370+
grid_dim = (grid_x, dim_y, dim_z)
371+
372+
if wk > 1:
373+
transposed_convolution_kernel_x = module.get_function(symbol_names[0])
374+
transposed_convolution_kernel_x(grid_dim, block_dim,
375+
(dim_x, dim_y, dim_z, x,
376+
in_dim_x, in_stride_y, in_stride_z, w, out))
377+
elif hk > 1:
378+
transposed_convolution_kernel_y = module.get_function(symbol_names[1])
379+
transposed_convolution_kernel_y(grid_dim, block_dim,
380+
(dim_x, dim_y, dim_z, x,
381+
in_dim_y, in_stride_y, in_stride_z, w, out))
382+
else:
383+
assert(False)
384+
373385
if pad != 0:
374386
out = out[:, :, pad[0] : out.shape[2] - pad[0], pad[1] : out.shape[3] - pad[1]]
375387
return cp.ascontiguousarray(out)

0 commit comments

Comments
 (0)