|
1 | 1 | ## This file contains direct Julia implementations of 2d and 3d convolutions
|
| 2 | +using Base.Threads |
2 | 3 |
|
3 | 4 | # Helper functions for restricting x/w overreach
|
4 | 5 | function clamp_lo(x, w)
|
@@ -57,50 +58,87 @@ function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
|
57 | 58 | stride_w, stride_h, stride_d = stride(cdims)
|
58 | 59 | out_width, out_height, out_depth = output_size(cdims)
|
59 | 60 |
|
60 |
| - # If we're doing crosscorr instead of conv, then don't bother to flip `w` |
61 |
| - if !flipkernel(cdims) |
62 |
| - w = w[end:-1:1, end:-1:1, end:-1:1, :, :] |
63 |
| - end |
64 |
| - |
| 61 | + # Create a method that, at compile-time, determines how we're going to index into `w` |
| 62 | + kproj(k, M, cdims::ConvDims{N,S,P,D,true}) where {N, S, P, D} = k |
| 63 | + kproj(k, M, cdims::ConvDims{N,S,P,D,false}) where {N, S, P, D} = M - k + 1 |
| 64 | + |
65 | 65 | # A helper function to project from output (w, h) to input (input_w, input_h)
|
66 |
| - @inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1 |
| 66 | + project(idx, stride, pad) = (idx - 1)*stride - pad + 1 |
67 | 67 |
|
68 |
| - # explicit formulation of convolution. Oh hoisting gods, hear my plea. |
69 |
| - @inbounds for batch in 1:size(x)[end], |
| 68 | + # Use `calc_padding_regions` to determine where we do or don't need to worry about padding |
| 69 | + padded_regions, central_region = calc_padding_regions(cdims) |
| 70 | + |
| 71 | + # Start with the central region |
| 72 | + w_region, h_region, d_region = central_region |
| 73 | + @inbounds for batch in 1:size(x, 5), |
| 74 | + c_out in 1:out_c, |
| 75 | + d_idx in d_region, |
| 76 | + h_idx in h_region, |
| 77 | + w_idx in w_region |
| 78 | + |
| 79 | + # Since we're in the central region, we don't need to worry about clamping |
| 80 | + dotprod = yT(0) |
| 81 | + for c_in in 1:channels_in(cdims), |
| 82 | + kd in 1:kernel_d, |
| 83 | + kh in 1:kernel_h, |
| 84 | + kw in 1:kernel_w |
| 85 | + |
| 86 | + # Hoist me, you coward. |
| 87 | + x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d |
| 88 | + x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h |
| 89 | + x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w |
| 90 | + |
| 91 | + x_val = x[x_w, x_h, x_d, c_in, batch] |
| 92 | + w_val = w[kproj(kw, kernel_w, cdims), |
| 93 | + kproj(kh, kernel_h, cdims), |
| 94 | + kproj(kd, kernel_d, cdims), |
| 95 | + c_in, c_out] |
| 96 | + dotprod = muladd(x_val, w_val, dotprod) |
| 97 | + end |
| 98 | + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] |
| 99 | + end |
| 100 | + |
| 101 | + # Next, do potentially-padded regions: |
| 102 | + @inbounds for (w_region, h_region, d_region) in padded_regions, |
| 103 | + batch in 1:size(x, 5), |
70 | 104 | c_out in 1:out_c,
|
71 |
| - d_idx in 1:out_depth, |
72 |
| - h_idx in 1:out_height, |
73 |
| - w_idx in 1:out_width |
74 |
| - |
75 |
| - # Starting points of the window of x we're going to grab |
76 |
| - x_w = project(w_idx, stride_w, pad_w_lo) |
77 |
| - x_h = project(h_idx, stride_h, pad_h_lo) |
78 |
| - x_d = project(d_idx, stride_d, pad_d_lo) |
79 |
| - |
80 |
| - # Grow that starting point into ranges |
81 |
| - x_widxs = x_w .+ (0:dil_w:(dil_w*kernel_w-1)) |
82 |
| - x_hidxs = x_h .+ (0:dil_h:(dil_h*kernel_h-1)) |
83 |
| - x_didxs = x_d .+ (0:dil_d:(dil_d*kernel_d-1)) |
84 |
| - w_widxs = 1:kernel_w |
85 |
| - w_hidxs = 1:kernel_h |
86 |
| - w_didxs = 1:kernel_d |
87 |
| - |
88 |
| - # Clamp the ranges to simulate padding |
89 |
| - x_widxs, w_widxs = clamp_lo(x_widxs, w_widxs) |
90 |
| - x_widxs, w_widxs = clamp_hi(x_widxs, w_widxs, width) |
91 |
| - x_hidxs, w_hidxs = clamp_lo(x_hidxs, w_hidxs) |
92 |
| - x_hidxs, w_hidxs = clamp_hi(x_hidxs, w_hidxs, height) |
93 |
| - x_didxs, w_didxs = clamp_lo(x_didxs, w_didxs) |
94 |
| - x_didxs, w_didxs = clamp_hi(x_didxs, w_didxs, depth) |
95 |
| - |
96 |
| - # Grab our slices |
97 |
| - x_slice = view(x, x_widxs, x_hidxs, x_didxs, :, batch) |
98 |
| - w_slice = view(w, w_widxs, w_hidxs, w_didxs, :, c_out) |
99 |
| - |
100 |
| - # Do the dotproduct dance, then weight by alpha/beta and git 'er done |
101 |
| - dotprod = sum(x_slice .* w_slice) |
102 |
| - y[w_idx, h_idx, d_idx, c_out, batch] = alpha*convert(yT, dotprod) + |
103 |
| - beta*y[w_idx, h_idx, d_idx, c_out, batch] |
| 105 | + d_idx in d_region, |
| 106 | + h_idx in h_region, |
| 107 | + w_idx in w_region |
| 108 | + |
| 109 | + # Probe for out-of-bounds accesses on `x` and `continue` if we hit one |
| 110 | + dotprod = yT(0) |
| 111 | + for c_in in 1:channels_in(cdims), |
| 112 | + kd in 1:kernel_d |
| 113 | + |
| 114 | + x_d = project(d_idx, stride_d, pad_d_lo) + (kd - 1)*dil_d |
| 115 | + if x_d <= 0 || x_d > depth |
| 116 | + continue |
| 117 | + end |
| 118 | + |
| 119 | + for kh in 1:kernel_h |
| 120 | + x_h = project(h_idx, stride_h, pad_h_lo) + (kh - 1)*dil_h |
| 121 | + if x_h <= 0 || x_h > height |
| 122 | + continue |
| 123 | + end |
| 124 | + |
| 125 | + for kw in 1:kernel_w |
| 126 | + x_w = project(w_idx, stride_w, pad_w_lo) + (kw - 1)*dil_w |
| 127 | + if x_w <= 0 || x_w > width |
| 128 | + continue |
| 129 | + end |
| 130 | + |
| 131 | + x_val = x[x_w, x_h, x_d, c_in, batch] |
| 132 | + w_val = w[kproj(kw, kernel_w, cdims), |
| 133 | + kproj(kh, kernel_h, cdims), |
| 134 | + kproj(kd, kernel_d, cdims), |
| 135 | + c_in, c_out] |
| 136 | + dotprod = muladd(x_val, w_val, dotprod) |
| 137 | + end |
| 138 | + end |
| 139 | + end |
| 140 | + |
| 141 | + y[w_idx, h_idx, d_idx, c_out, batch] = alpha*dotprod + beta*y[w_idx, h_idx, d_idx, c_out, batch] |
104 | 142 | end
|
105 | 143 |
|
106 | 144 | return y
|
|
0 commit comments