Skip to content

Commit 4ebb419

Browse files
authored
Merge pull request #421 from mloubout/master
support complex input for upsample
2 parents b02b8c2 + 22ef274 commit 4ebb419

File tree

2 files changed

+42
-18
lines changed

2 files changed

+42
-18
lines changed

src/upsample.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,14 @@ upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) =
149149
function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
150150
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
151151
in_w, channels, batches = size(input)
152+
RT = real(T)
152153
# treat batch and channel dimension as one for better parallelization granularity
153154
channels *= batches
154155
out_w, _, _ = size(output)
155156
output_slice_size = out_w
156157

157-
# T() and // so that we can handle rationals (super slow)
158-
width_scale = T((in_w - 1) // (out_w - 1))
158+
#real(T)() and // so that we can handle rationals (super slow)
159+
width_scale = RT((in_w - 1) // (out_w - 1))
159160

160161
@inline idx(c, w) = c * in_w + w + 1
161162

@@ -175,14 +176,15 @@ end
175176
function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
176177
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
177178
in_w, in_h, channels, batches = size(input)
179+
RT = real(T)
178180
# treat batch and channel dimension as one for better parallelization granularity
179181
channels *= batches
180182
out_w, out_h, _, _ = size(output)
181183
output_slice_size = out_h * out_w
182184

183-
# T() and // so that we can handle rationals (super slow)
184-
width_scale = T((in_w - 1) // (out_w - 1))
185-
height_scale = T((in_h - 1) // (out_h - 1))
185+
#real(T)() and // so that we can handle rationals (super slow)
186+
width_scale = RT((in_w - 1) // (out_w - 1))
187+
height_scale = RT((in_h - 1) // (out_h - 1))
186188

187189
@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1
188190

@@ -208,15 +210,16 @@ end
208210
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
209211
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
210212
in_w, in_h, in_d, channels, batches = size(input)
213+
RT = real(T)
211214
# treat batch and channel dimension as one for better parallelization granularity
212215
channels *= batches
213216
out_w, out_h, out_d, _, _ = size(output)
214217
output_slice_size = out_h * out_w * out_d
215218

216-
# T() and // so that we can handle rationals (super slow)
217-
width_scale = T((in_w - 1) // (out_w - 1))
218-
height_scale = T((in_h - 1) // (out_h - 1))
219-
depth_scale = T((in_d - 1) // (out_d - 1))
219+
#real(T)() and // so that we can handle rationals (super slow)
220+
width_scale = RT((in_w - 1) // (out_w - 1))
221+
height_scale = RT((in_h - 1) // (out_h - 1))
222+
depth_scale = RT((in_d - 1) // (out_d - 1))
220223

221224
@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1
222225

@@ -268,13 +271,13 @@ end
268271
function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
269272
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
270273
in_w, channels, batches = size(dx)
271-
274+
RT = real(T)
272275
# treat batch and channel dimension as one for better parallelization granularity
273276
channels *= batches
274277
out_w, _, _ = size(Δ)
275278
output_slice_size = out_w
276279

277-
width_scale = T((in_w - 1) // (out_w - 1))
280+
width_scale = RT((in_w - 1) // (out_w - 1))
278281

279282
@inline idx(c, w) = c * in_w + w + 1
280283

@@ -294,14 +297,14 @@ end
294297
function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
295298
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
296299
in_w, in_h, channels, batches = size(dx)
297-
300+
RT = real(T)
298301
# treat batch and channel dimension as one for better parallelization granularity
299302
channels *= batches
300303
out_w, out_h, _, _ = size(Δ)
301304
output_slice_size = out_h * out_w
302305

303-
width_scale = T((in_w - 1) // (out_w - 1))
304-
height_scale = T((in_h - 1) // (out_h - 1))
306+
width_scale = RT((in_w - 1) // (out_w - 1))
307+
height_scale = RT((in_h - 1) // (out_h - 1))
305308

306309
@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1
307310

@@ -326,15 +329,16 @@ end
326329
function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
327330
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
328331
in_w, in_h, in_d, channels, batches = size(dx)
332+
RT = real(T)
329333
# treat batch and channel dimension as one for better parallelization granularity
330334
channels *= batches
331335
out_w, out_h, out_d, _, _ = size(Δ)
332336
output_slice_size = out_h * out_w * out_d
333337

334-
# T() and // so that we can handle rationals (super slow)
335-
width_scale = T((in_w - 1) // (out_w - 1))
336-
height_scale = T((in_h - 1) // (out_h - 1))
337-
depth_scale = T((in_d - 1) // (out_d - 1))
338+
#real(T)() and // so that we can handle rationals (super slow)
339+
width_scale = RT((in_w - 1) // (out_w - 1))
340+
height_scale = RT((in_h - 1) // (out_h - 1))
341+
depth_scale = RT((in_d - 1) // (out_d - 1))
338342

339343
@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1
340344

test/upsample.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,23 @@ end
172172
gradtest(x -> pixel_shuffle(x, r), x)
173173
end
174174
end
175+
176+
@testset "Complex-valued upsample" begin
177+
for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear])
178+
for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest])
179+
x = randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1)
180+
181+
upsize = (8, 16, 24)[1:d]
182+
xup = interp(x, k)
183+
@test size(xup)[1:d] == upsize
184+
@test real(xup) == interp(real(x), k)
185+
@test imag(xup) == interp(imag(x), k)
186+
187+
upsize = (8,24,48)[1:d]
188+
xup = interp(x; size=upsize)
189+
@test size(xup)[1:d] == upsize
190+
@test real(xup) == interp(real(x), size=upsize)
191+
@test imag(xup) == interp(imag(x), size=upsize)
192+
end
193+
end
194+
end

0 commit comments

Comments
 (0)