Skip to content

Commit b8ed90b

Browse files
use Base definition of stack
1 parent 4192164 commit b8ed90b

File tree

5 files changed

+8
-71
lines changed

5 files changed

+8
-71
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.3.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
910
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1011
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"

src/MLUtils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ import NNlib
2222
@traitdef IsTable{X}
2323
@traitimpl IsTable{X} <- Tables.istable(X)
2424

25-
25+
if VERSION < v"1.9.0-DEV.1163"
26+
import Compat: stack
27+
else
28+
import Base: stack
29+
end
2630
include("observation.jl")
2731
export numobs,
2832
getobs,
@@ -75,7 +79,7 @@ export batch,
7579
rand_like,
7680
randn_like,
7781
rpad_constant,
78-
stack,
82+
stack, # in Base since julia v1.9
7983
unbatch,
8084
unsqueeze,
8185
unstack,

src/deprecations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Deprecated in v0.2
2-
@deprecate stack(x, dims) stack(x; dims=dims)
32
@deprecate unstack(x, dims) unstack(x; dims=dims)
43
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
54
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)

src/utils.jl

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -57,72 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)
5757

5858
Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")
5959

60-
"""
61-
stack(xs; dims)
62-
63-
Concatenate the given array of arrays `xs` into a single array along the
64-
new dimension `dims`. All arrays need to be of the same size.
65-
66-
See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).
67-
68-
# Examples
69-
70-
```jldoctest
71-
julia> xs = [[1, 2], [3, 4], [5, 6]]
72-
3-element Vector{Vector{Int64}}:
73-
[1, 2]
74-
[3, 4]
75-
[5, 6]
76-
77-
julia> stack(xs, dims=1)
78-
3×2 Matrix{Int64}:
79-
1 2
80-
3 4
81-
5 6
82-
83-
julia> stack(xs, dims=2)
84-
2×3 Matrix{Int64}:
85-
1 3 5
86-
2 4 6
87-
88-
julia> stack(xs, dims=3)
89-
2×1×3 Array{Int64, 3}:
90-
[:, :, 1] =
91-
1
92-
2
93-
94-
[:, :, 2] =
95-
3
96-
4
97-
98-
[:, :, 3] =
99-
5
100-
6
101-
```
102-
"""
103-
function stack(xs; dims::Int)
104-
N = ndims(xs[1])
105-
if dims <= N
106-
vs = unsqueeze.(xs; dims)
107-
else
108-
vs = xs
109-
end
110-
if dims == 1
111-
return reduce(vcat, vs)
112-
elseif dims === 2
113-
return reduce(hcat, vs)
114-
else
115-
return reduce((x, y) -> cat(x, y; dims=dims), vs)
116-
end
117-
end
118-
119-
function rrule(::typeof(stack), xs; dims::Int)
120-
function stack_pullback(Δ)
121-
return (NoTangent(), unstack(unthunk(Δ); dims=dims))
122-
end
123-
return stack(xs; dims=dims), stack_pullback
124-
end
125-
12660
"""
12761
unstack(xs; dims)
12862

test/utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414
x = randn(3,3)
1515
stacked = stack([x, x], dims=2)
1616
@test size(stacked) == (3,2,3)
17-
@test_broken @inferred(stack([x, x], dims=2)) == stacked
17+
@test @inferred(stack([x, x], dims=2)) == stacked
1818

1919
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
2020
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@@ -30,7 +30,6 @@ end
3030
a = [[1] for i in 1:10000]
3131
@test size(stack(a, dims=1)) == (10000, 1)
3232
@test size(stack(a, dims=2)) == (1, 10000)
33-
@test size(stack(a, dims=3)) == (1, 1, 10000)
3433
end
3534

3635
@testset "batch and unbatch" begin

0 commit comments

Comments
 (0)