Skip to content

Commit 507f744

Browse files
Merge pull request #414 from maxfreu/deflate-upsampling
partly revert changes to stay compatible w NNlibCUDA
2 parents b564601 + ddc1e03 commit 507f744

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

src/upsample.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,19 @@ function upsample_linear(x::AbstractArray{T,<:Any}; size) where T<:Integer
134134
return round.(T, res)
135135
end
136136

137+
# compatibility layer for old versions of NNlibCUDA
138+
# old versions overload upsample_linear_wcn, new versions overload upsample_linear_kernel
139+
# can be removed from NNlib 0.9, i.e. revert https://github.com/FluxML/NNlib.jl/pull/414
140+
# IF https://github.com/FluxML/NNlibCUDA.jl/pull/49 has been merged
141+
upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = upsample_linear_wcn!(y,x)
142+
upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = upsample_bilinear_whcn!(y,x)
143+
upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = upsample_trilinear_whdcn!(y,x)
144+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = ∇upsample_linear_wcn!(y,x)
145+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = ∇upsample_bilinear_whcn!(y,x)
146+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = ∇upsample_trilinear_whdcn!(y,x)
147+
137148
# linearly upsamples first dim of 3D array
138-
function upsample_linear_kernel!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
149+
function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
139150
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
140151
in_w, channels, batches = size(input)
141152
# treat batch and channel dimension as one for better parallelization granularity
@@ -161,7 +172,7 @@ end
161172

162173
# bilinear
163174
# linearly upsamples first two dims of 4D array
164-
function upsample_linear_kernel!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
175+
function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
165176
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
166177
in_w, in_h, channels, batches = size(input)
167178
# treat batch and channel dimension as one for better parallelization granularity
@@ -194,7 +205,7 @@ end
194205

195206
# trilinear
196207
# linearly upsamples first three dims of 5D array
197-
function upsample_linear_kernel!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
208+
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
198209
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
199210
in_w, in_h, in_d, channels, batches = size(input)
200211
# treat batch and channel dimension as one for better parallelization granularity
@@ -254,7 +265,7 @@ function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer})
254265
end
255266

256267
# linear
257-
function upsample_linear_kernel!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
268+
function upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
258269
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
259270
in_w, channels, batches = size(dx)
260271

@@ -280,7 +291,7 @@ function ∇upsample_linear_kernel!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,
280291
end
281292

282293
# bilinear
283-
function upsample_linear_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
294+
function upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
284295
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
285296
in_w, in_h, channels, batches = size(dx)
286297

@@ -312,7 +323,7 @@ function ∇upsample_linear_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,
312323
end
313324

314325
# trilinear
315-
function upsample_linear_kernel!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
326+
function upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
316327
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
317328
in_w, in_h, in_d, channels, batches = size(dx)
318329
# treat batch and channel dimension as one for better parallelization granularity

0 commit comments

Comments
 (0)