@@ -25,3 +25,105 @@ function reduce_window(
25
25
base_dilations= ones (Int, N),
26
26
)[1 ]
27
27
end
28
+
29
+ function upsample_linear (
30
+ x:: AnyTracedRArray{T,3} , out_size:: Tuple{Int} , rwidth, align_corners:: Bool
31
+ ) where {T}
32
+ W, _, _ = size (x)
33
+
34
+ out_idxs = Ops. iota (Int32, [out_size[1 ]]; iota_dimension= 1 )
35
+ iw0, iw1, w0_λ, w1_λ = source_idx_and_λ (rwidth, out_idxs, align_corners, W)
36
+
37
+ x0 = x[iw0, :, :]
38
+ x1 = x[iw1, :, :]
39
+
40
+ return w0_λ .* x0 .+ w1_λ .* x1
41
+ end
42
+
43
+ function upsample_linear (
44
+ x:: AnyTracedRArray{T,4} , out_size:: Tuple{Int,Int} , rwidth, rheight, align_corners:: Bool
45
+ ) where {T}
46
+ W, H, _, _ = size (x)
47
+
48
+ out_width = Ops. iota (Int32, [out_size[1 ]]; iota_dimension= 1 )
49
+ out_height = Ops. iota (Int32, [out_size[2 ]]; iota_dimension= 1 )
50
+
51
+ iw0, iw1, w0_λ, w1_λ = source_idx_and_λ (rwidth, out_width, align_corners, W)
52
+ ih0, ih1, h0_λ, h1_λ = source_idx_and_λ (rheight, out_height, align_corners, H)
53
+
54
+ w0_λ, w1_λ = reshape (w0_λ, (:, 1 , 1 , 1 )), reshape (w1_λ, (:, 1 , 1 , 1 ))
55
+ h0_λ, h1_λ = reshape (h0_λ, (1 , :, 1 , 1 )), reshape (h1_λ, (1 , :, 1 , 1 ))
56
+
57
+ x00 = x[iw0, ih0, :, :]
58
+ x10 = x[iw1, ih0, :, :]
59
+ x01 = x[iw0, ih1, :, :]
60
+ x11 = x[iw1, ih1, :, :]
61
+
62
+ return h0_λ .* (w0_λ .* x00 .+ w1_λ .* x10) .+ h1_λ .* (w0_λ .* x01 .+ w1_λ .* x11)
63
+ end
64
+
65
+ function upsample_linear (
66
+ x:: AnyTracedRArray{T,5} ,
67
+ out_size:: Tuple{Int,Int,Int} ,
68
+ rwidth,
69
+ rheight,
70
+ rdepth,
71
+ align_corners:: Bool ,
72
+ ) where {T}
73
+ W, H, D, _, _ = size (x)
74
+
75
+ out_width = Ops. iota (Int32, [out_size[1 ]]; iota_dimension= 1 )
76
+ out_height = Ops. iota (Int32, [out_size[2 ]]; iota_dimension= 1 )
77
+ out_depth = Ops. iota (Int32, [out_size[3 ]]; iota_dimension= 1 )
78
+
79
+ iw0, iw1, w0_λ, w1_λ = source_idx_and_λ (rwidth, out_width, align_corners, W)
80
+ ih0, ih1, h0_λ, h1_λ = source_idx_and_λ (rheight, out_height, align_corners, H)
81
+ id0, id1, d0_λ, d1_λ = source_idx_and_λ (rdepth, out_depth, align_corners, D)
82
+
83
+ w0_λ = reshape (w0_λ, (:, 1 , 1 , 1 ))
84
+ w1_λ = reshape (w1_λ, (:, 1 , 1 , 1 ))
85
+ h0_λ = reshape (h0_λ, (1 , :, 1 , 1 ))
86
+ h1_λ = reshape (h1_λ, (1 , :, 1 , 1 ))
87
+ d0_λ = reshape (d0_λ, (1 , 1 , :, 1 ))
88
+ d1_λ = reshape (d1_λ, (1 , 1 , :, 1 ))
89
+
90
+ x000 = x[iw0, ih0, id0, :, :]
91
+ x100 = x[iw1, ih0, id0, :, :]
92
+ x010 = x[iw0, ih1, id0, :, :]
93
+ x110 = x[iw1, ih1, id0, :, :]
94
+
95
+ x001 = x[iw0, ih0, id1, :, :]
96
+ x101 = x[iw1, ih0, id1, :, :]
97
+ x011 = x[iw0, ih1, id1, :, :]
98
+ x111 = x[iw1, ih1, id1, :, :]
99
+
100
+ return (
101
+ (
102
+ d0_λ .* (
103
+ h0_λ .* (w0_λ .* x000 .+ w1_λ .* x100) .+
104
+ h1_λ .* (w0_λ .* x010 .+ w1_λ .* x110)
105
+ )
106
+ ) .+ (
107
+ d1_λ .* (
108
+ h0_λ .* (w0_λ .* x001 .+ w1_λ .* x101) .+
109
+ h1_λ .* (w0_λ .* x011 .+ w1_λ .* x111)
110
+ )
111
+ )
112
+ )
113
+ end
114
+
115
+ @inline function source_idx_and_λ (
116
+ ratio:: T , out_idx:: AbstractVector , align:: Bool , in_width:: Int
117
+ ) where {T}
118
+ real_index = ifelse (
119
+ align, ratio .* out_idx, max .(zero (T), ratio .* (out_idx .+ T (0.5 )) .- T (0.5 ))
120
+ )
121
+
122
+ iw0 = Base. Fix1 (floor, Int).(real_index)
123
+ offset = ifelse .(iw0 .< in_width - 1 , 1 , 0 )
124
+ iw1 = iw0 .+ offset .+ 1
125
+
126
+ w1lambda = real_index .- iw0
127
+ w0lambda = one (T) .- w1lambda
128
+ return iw0 .+ 1 , iw1, w0lambda, w1lambda
129
+ end
0 commit comments