Skip to content

Commit aec9957

Browse files
feat: upsampling functions (#1387)
* test: pixel_shuffle * feat: upsample linear * Update test/nn/nnlib.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * feat: upsample bilinear & trilinear * test: upsampling functions --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 565a1a2 commit aec9957

File tree

3 files changed

+162
-1
lines changed

3 files changed

+162
-1
lines changed

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,3 +474,16 @@ function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::I
474474
slice_sizes=Int64[size(src)[1:n_dims]..., ones(Int64, ndims(src) - n_dims)...],
475475
)
476476
end
477+
478+
function NNlib.upsample_linear_kernel!(
479+
y::AnyTracedRArray{T,N}, x::AnyTracedRArray{T,N}; align_corners::Bool=true
480+
) where {T,N}
481+
wT = real(Reactant.unwrapped_eltype(T))
482+
ratios = if align_corners
483+
ntuple(i -> wT((size(x, i) - 1) / (size(y, i) - 1)), N - 2)
484+
else
485+
ntuple(i -> wT(size(x, i) / size(y, i)), N - 2)
486+
end
487+
copyto!(y, upsample_linear(x, size(y)[1:(end - 2)], ratios..., align_corners))
488+
return y
489+
end

ext/ReactantNNlibExt/Ops.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,105 @@ function reduce_window(
2525
base_dilations=ones(Int, N),
2626
)[1]
2727
end
28+
29+
function upsample_linear(
30+
x::AnyTracedRArray{T,3}, out_size::Tuple{Int}, rwidth, align_corners::Bool
31+
) where {T}
32+
W, _, _ = size(x)
33+
34+
out_idxs = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
35+
iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_idxs, align_corners, W)
36+
37+
x0 = x[iw0, :, :]
38+
x1 = x[iw1, :, :]
39+
40+
return w0_λ .* x0 .+ w1_λ .* x1
41+
end
42+
43+
function upsample_linear(
44+
x::AnyTracedRArray{T,4}, out_size::Tuple{Int,Int}, rwidth, rheight, align_corners::Bool
45+
) where {T}
46+
W, H, _, _ = size(x)
47+
48+
out_width = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
49+
out_height = Ops.iota(Int32, [out_size[2]]; iota_dimension=1)
50+
51+
iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_width, align_corners, W)
52+
ih0, ih1, h0_λ, h1_λ = source_idx_and_λ(rheight, out_height, align_corners, H)
53+
54+
w0_λ, w1_λ = reshape(w0_λ, (:, 1, 1, 1)), reshape(w1_λ, (:, 1, 1, 1))
55+
h0_λ, h1_λ = reshape(h0_λ, (1, :, 1, 1)), reshape(h1_λ, (1, :, 1, 1))
56+
57+
x00 = x[iw0, ih0, :, :]
58+
x10 = x[iw1, ih0, :, :]
59+
x01 = x[iw0, ih1, :, :]
60+
x11 = x[iw1, ih1, :, :]
61+
62+
return h0_λ .* (w0_λ .* x00 .+ w1_λ .* x10) .+ h1_λ .* (w0_λ .* x01 .+ w1_λ .* x11)
63+
end
64+
65+
function upsample_linear(
66+
x::AnyTracedRArray{T,5},
67+
out_size::Tuple{Int,Int,Int},
68+
rwidth,
69+
rheight,
70+
rdepth,
71+
align_corners::Bool,
72+
) where {T}
73+
W, H, D, _, _ = size(x)
74+
75+
out_width = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
76+
out_height = Ops.iota(Int32, [out_size[2]]; iota_dimension=1)
77+
out_depth = Ops.iota(Int32, [out_size[3]]; iota_dimension=1)
78+
79+
iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_width, align_corners, W)
80+
ih0, ih1, h0_λ, h1_λ = source_idx_and_λ(rheight, out_height, align_corners, H)
81+
id0, id1, d0_λ, d1_λ = source_idx_and_λ(rdepth, out_depth, align_corners, D)
82+
83+
w0_λ = reshape(w0_λ, (:, 1, 1, 1))
84+
w1_λ = reshape(w1_λ, (:, 1, 1, 1))
85+
h0_λ = reshape(h0_λ, (1, :, 1, 1))
86+
h1_λ = reshape(h1_λ, (1, :, 1, 1))
87+
d0_λ = reshape(d0_λ, (1, 1, :, 1))
88+
d1_λ = reshape(d1_λ, (1, 1, :, 1))
89+
90+
x000 = x[iw0, ih0, id0, :, :]
91+
x100 = x[iw1, ih0, id0, :, :]
92+
x010 = x[iw0, ih1, id0, :, :]
93+
x110 = x[iw1, ih1, id0, :, :]
94+
95+
x001 = x[iw0, ih0, id1, :, :]
96+
x101 = x[iw1, ih0, id1, :, :]
97+
x011 = x[iw0, ih1, id1, :, :]
98+
x111 = x[iw1, ih1, id1, :, :]
99+
100+
return (
101+
(
102+
d0_λ .* (
103+
h0_λ .* (w0_λ .* x000 .+ w1_λ .* x100) .+
104+
h1_λ .* (w0_λ .* x010 .+ w1_λ .* x110)
105+
)
106+
) .+ (
107+
d1_λ .* (
108+
h0_λ .* (w0_λ .* x001 .+ w1_λ .* x101) .+
109+
h1_λ .* (w0_λ .* x011 .+ w1_λ .* x111)
110+
)
111+
)
112+
)
113+
end
114+
115+
@inline function source_idx_and_λ(
116+
ratio::T, out_idx::AbstractVector, align::Bool, in_width::Int
117+
) where {T}
118+
real_index = ifelse(
119+
align, ratio .* out_idx, max.(zero(T), ratio .* (out_idx .+ T(0.5)) .- T(0.5))
120+
)
121+
122+
iw0 = Base.Fix1(floor, Int).(real_index)
123+
offset = ifelse.(iw0 .< in_width - 1, 1, 0)
124+
iw1 = iw0 .+ offset .+ 1
125+
126+
w1lambda = real_index .- iw0
127+
w0lambda = one(T) .- w1lambda
128+
return iw0 .+ 1, iw1, w0lambda, w1lambda
129+
end

test/nn/nnlib.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,53 @@ end
416416
x = randn(Float32, 4, 4, 3, 2)
417417
x_ra = Reactant.to_rarray(x)
418418

419-
@test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) NNlib.upsample_nearest(x, (2, 2))
419+
@testset "Nearest" begin
420+
@test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) NNlib.upsample_nearest(x, (2, 2))
421+
end
422+
423+
@testset "Linear" begin
424+
x = randn(Float32, 4, 3, 2)
425+
x_ra = Reactant.to_rarray(x)
426+
427+
@test @jit(NNlib.upsample_linear(x_ra, (2,))) NNlib.upsample_linear(x, (2,))
428+
429+
@test @jit(NNlib.upsample_linear(x_ra, (2,); align_corners=false))
430+
NNlib.upsample_linear(x, (2,); align_corners=false)
431+
end
432+
433+
@testset "Bi-Linear" begin
434+
x = randn(Float32, 4, 4, 3, 2)
435+
x_ra = Reactant.to_rarray(x)
436+
437+
@test @jit(NNlib.upsample_bilinear(x_ra, (2, 2)))
438+
NNlib.upsample_bilinear(x, (2, 2))
439+
440+
@test @jit(NNlib.upsample_bilinear(x_ra, (2, 2); align_corners=false))
441+
NNlib.upsample_bilinear(x, (2, 2); align_corners=false)
442+
end
443+
444+
@testset "Tri-Linear" begin
445+
x = randn(Float32, 4, 4, 4, 3, 2)
446+
x_ra = Reactant.to_rarray(x)
447+
448+
@test @jit(NNlib.upsample_trilinear(x_ra, (2, 2, 2)))
449+
NNlib.upsample_trilinear(x, (2, 2, 2))
450+
451+
@test @jit(NNlib.upsample_trilinear(x_ra, (2, 2, 2); align_corners=false))
452+
NNlib.upsample_trilinear(x, (2, 2, 2); align_corners=false)
453+
end
454+
end
455+
456+
@testset "Pixel shuffle" begin
457+
x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
458+
x_ra = Reactant.to_rarray(x)
459+
460+
@test @jit(NNlib.pixel_shuffle(x_ra, 2)) NNlib.pixel_shuffle(x, 2)
461+
462+
y = [i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1]
463+
y_ra = Reactant.to_rarray(y)
464+
465+
@test @jit(NNlib.pixel_shuffle(y_ra, 2)) NNlib.pixel_shuffle(y, 2)
420466
end
421467

422468
@testset "softmax/logsoftmax reshaped input" begin

0 commit comments

Comments
 (0)