Skip to content

Commit ba74c7b

Browse files
improvements to stack
1 parent 125205b commit ba74c7b

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

src/utils.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
Return `x` reshaped into an array one dimensionality higher than `x`,
77
where `dims` indicates in which dimension `x` is extended.
8+
`dims` can be an integer between 1 and `ndims(x)+1`.
89
910
See also [`flatten`](@ref), [`stack`](@ref).
1011
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
3334
[1, 2] [3, 4] [5, 6]
3435
```
3536
"""
36-
function unsqueeze(x::AbstractArray; dims::Int)
37-
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1)
37+
function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N}
38+
# @assert 1 <= dims <= N + 1
39+
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1)
3840
return reshape(x, sz)
3941
end
4042

@@ -59,9 +61,11 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
5961
stack(xs; dims)
6062
6163
Concatenate the given array of arrays `xs` into a single array along the
62-
given dimension `dims`.
64+
given dimension `dims`. All arrays need to be of the same size.
65+
The number of dimension in the final arrays is one more than the number
66+
of dimensions in the input arrays.
6367
64-
See also [`stack`](@ref) and [`batch`](@ref).
68+
See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).
6569
6670
# Examples
6771
@@ -98,7 +102,28 @@ julia> stack(xs, dims=3)
98102
6
99103
```
100104
"""
101-
stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims)
105+
function stack(xs; dims::Int)
106+
N = ndims(xs[1])
107+
if dims <= N
108+
vs = unsqueeze.(xs; dims)
109+
else
110+
vs = xs
111+
end
112+
if dims == 1
113+
return reduce(vcat, vs)
114+
elseif dims === 2
115+
return reduce(hcat, vs)
116+
else
117+
return reduce((x, y) -> cat(x, y; dims=dims), vs)
118+
end
119+
end
120+
121+
function rrule(::typeof(stack), xs; dims::Int)
122+
function stack_pullback(Δ)
123+
return (NoTangent(), unstack(unthunk(Δ); dims=dims))
124+
end
125+
return stack(xs; dims=dims), stack_pullback
126+
end
102127

103128
"""
104129
unstack(xs; dims)

test/test_utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
11

2+
"""
3+
Test gradients through zygote.
4+
5+
# Arguments
6+
7+
- `f`: function to test
8+
- `xs`: inputs to `f`
9+
10+
# Keyword Arguments
11+
Keyword arguments are passed to `rrule`.
12+
13+
- `fkwargs`: keyword arguments to `f`
14+
"""
215
function test_zygote(f, xs...; kws...)
316
config = ZygoteRuleConfig()
417
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
@test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1)
77

88
@test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2)
9+
10+
@test_throws AssertionError unsqueeze(rand(2,2), dims=4)
911
end
1012

1113
@testset "stack and unstack" begin
@@ -19,6 +21,16 @@ end
1921
@test unstack(stacked_array, dims=2) == unstacked_array
2022
@test stack(unstacked_array, dims=2) == stacked_array
2123
@test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array
24+
25+
for d in (1,2,3)
26+
test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false)
27+
end
28+
29+
# Issue #121
30+
a = [[1] for i in 1:10000]
31+
@test size(stack(a, dims=1)) == (10000, 1)
32+
@test size(stack(a, dims=2)) == (1, 10000)
33+
@test size(stack(a, dims=3)) == (1, 1, 10000)
2234
end
2335

2436
@testset "batch and unbatch" begin

0 commit comments

Comments
 (0)