5
5
6
6
Return `x` reshaped into an array one dimensionality higher than `x`,
7
7
where `dims` indicates in which dimension `x` is extended.
8
+ `dims` can be an integer between 1 and `ndims(x)+1`.
8
9
9
10
See also [`flatten`](@ref), [`stack`](@ref).
10
11
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
33
34
[1, 2] [3, 4] [5, 6]
34
35
```
35
36
"""
36
- function unsqueeze (x:: AbstractArray ; dims:: Int )
37
- sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), ndims (x) + 1 )
37
+ function unsqueeze (x:: AbstractArray{T,N} ; dims:: Int ) where {T, N}
38
+ # @assert 1 <= dims <= N + 1
39
+ sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), N + 1 )
38
40
return reshape (x, sz)
39
41
end
40
42
@@ -59,9 +61,11 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
59
61
stack(xs; dims)
60
62
61
63
Concatenate the given array of arrays `xs` into a single array along the
62
- given dimension `dims`.
64
+ given dimension `dims`. All arrays need to be of the same size.
65
+ The number of dimension in the final arrays is one more than the number
66
+ of dimensions in the input arrays.
63
67
64
- See also [`stack `](@ref) and [`batch`](@ref).
68
+ See also [`unsqueeze`](@ref), [`unstack `](@ref) and [`batch`](@ref).
65
69
66
70
# Examples
67
71
@@ -98,7 +102,28 @@ julia> stack(xs, dims=3)
98
102
6
99
103
```
100
104
"""
101
- stack (xs; dims:: Int ) = cat (unsqueeze .(xs; dims)... ; dims)
105
+ function stack (xs; dims:: Int )
106
+ N = ndims (xs[1 ])
107
+ if dims <= N
108
+ vs = unsqueeze .(xs; dims)
109
+ else
110
+ vs = xs
111
+ end
112
+ if dims == 1
113
+ return reduce (vcat, vs)
114
+ elseif dims === 2
115
+ return reduce (hcat, vs)
116
+ else
117
+ return reduce ((x, y) -> cat (x, y; dims= dims), vs)
118
+ end
119
+ end
120
+
121
+ function rrule (:: typeof (stack), xs; dims:: Int )
122
+ function stack_pullback (Δ)
123
+ return (NoTangent (), unstack (unthunk (Δ); dims= dims))
124
+ end
125
+ return stack (xs; dims= dims), stack_pullback
126
+ end
102
127
103
128
"""
104
129
unstack(xs; dims)
0 commit comments