Skip to content

Commit f9aec37

Browse files
author
Max Freudenberg
committed
partly revert changes to stay compatible w NNlibCUDA
1 parent b564601 commit f9aec37

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/upsample.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,17 @@ function upsample_linear(x::AbstractArray{T,<:Any}; size) where T<:Integer
134134
return round.(T, res)
135135
end
136136

137+
# compatibility layer for old versions of NNlibCUDA
138+
# old versions overload upsample_linear_wcn, new versions overload upsample_linear_kernel
139+
upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = upsample_linear_wcn!(y,x)
140+
upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = upsample_bilinear_whcn!(y,x)
141+
upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = upsample_trilinear_whdcn!(y,x)
142+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}) = ∇upsample_linear_wcn!(y,x)
143+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = ∇upsample_bilinear_whcn!(y,x)
144+
∇upsample_linear_kernel!(y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5}) = ∇upsample_trilinear_whdcn!(y,x)
145+
137146
# linearly upsamples first dim of 3D array
138-
function upsample_linear_kernel!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
147+
function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
139148
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
140149
in_w, channels, batches = size(input)
141150
# treat batch and channel dimension as one for better parallelization granularity
@@ -161,7 +170,7 @@ end
161170

162171
# bilinear
163172
# linearly upsamples first two dims of 4D array
164-
function upsample_linear_kernel!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
173+
function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
165174
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
166175
in_w, in_h, channels, batches = size(input)
167176
# treat batch and channel dimension as one for better parallelization granularity
@@ -194,7 +203,7 @@ end
194203

195204
# trilinear
196205
# linearly upsamples first three dims of 5D array
197-
function upsample_linear_kernel!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
206+
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
198207
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
199208
in_w, in_h, in_d, channels, batches = size(input)
200209
# treat batch and channel dimension as one for better parallelization granularity
@@ -254,7 +263,7 @@ function ∇upsample_linear(Δ::AbstractArray{T,N}; size::NTuple{<:Any,Integer})
254263
end
255264

256265
# linear
257-
function upsample_linear_kernel!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
266+
function upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
258267
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
259268
in_w, channels, batches = size(dx)
260269

@@ -280,7 +289,7 @@ function ∇upsample_linear_kernel!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,
280289
end
281290

282291
# bilinear
283-
function upsample_linear_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
292+
function upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
284293
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
285294
in_w, in_h, channels, batches = size(dx)
286295

@@ -312,7 +321,7 @@ function ∇upsample_linear_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,
312321
end
313322

314323
# trilinear
315-
function upsample_linear_kernel!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
324+
function upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
316325
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
317326
in_w, in_h, in_d, channels, batches = size(dx)
318327
# treat batch and channel dimension as one for better parallelization granularity

0 commit comments

Comments
 (0)