Skip to content

Commit 211e6a6

Browse files
authored
Merge pull request #8 from mcabbott/ragged
Ragged stack
2 parents 48b7680 + e1e2b6d commit 211e6a6

File tree

2 files changed

+151
-21
lines changed

2 files changed

+151
-21
lines changed

src/LazyStack.jl

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
module LazyStack
22

3-
export stack
3+
export stack, rstack
44

55
#===== Tuples =====#
66

77
ndims(A) = Base.ndims(A)
88
ndims(::Tuple) = 1
99
ndims(::NamedTuple) = 1
1010

11+
axes(A) = Base.axes(A)
12+
axes(A, d) = Base.axes(A, d)
13+
if VERSION < v"1.1" # because, on Julia 1.0, axes((1,2)) === Base.OneTo(2)
14+
axes(t::Tuple) = tuple(Base.axes(t))
15+
end
16+
axes(nt::NamedTuple) = tuple(Base.OneTo(length(nt)))
17+
1118
size(A) = Base.size(A)
1219
size(t::Tuple) = tuple(length(t))
1320
size(t::NamedTuple) = tuple(length(t))
@@ -80,9 +87,6 @@ Base.size(x::Stacked) = (size(first(x.slices))..., size(x.slices)...)
8087
Base.size(x::Stacked{T,N,<:Tuple}) where {T,N} = (size(first(x.slices))..., length(x.slices))
8188

8289
Base.axes(x::Stacked) = (axes(first(x.slices))..., axes(x.slices)...)
83-
if VERSION < v"1.1" # axes((1:9, 1:9)) == Base.OneTo(2) # on Julia 1.0
84-
Base.axes(x::Stacked{T,N,<:Tuple}) where {T,N} = (axes(first(x.slices))..., axes(x.slices))
85-
end
8690

8791
Base.parent(x::Stacked) = x.slices
8892

@@ -368,4 +372,100 @@ function storage_type(x::AbstractArray)
368372
typeof(x) === typeof(p) ? typeof(x) : storage_type(p)
369373
end
370374

375+
#===== Ragged =====#
376+
377+
"""
378+
rstack(arrays; fill=0)
379+
380+
Ragged `stack`, which allows slices of varying size, and fills the gaps with zero
381+
or the given `fill`. Always returns an `Array`.
382+
383+
```
384+
julia> rstack(1:n for n in 1:5)
385+
5×5 Array{Int64,2}:
386+
1 1 1 1 1
387+
0 2 2 2 2
388+
0 0 3 3 3
389+
0 0 0 4 4
390+
0 0 0 0 5
391+
392+
julia> rstack([[1,2,3], [10,20.0]], fill=missing)
393+
3×2 Array{Union{Missing, Float64},2}:
394+
1.0 10.0
395+
2.0 20.0
396+
3.0 missing
397+
398+
julia> using OffsetArrays
399+
400+
julia> rstack(1:3, OffsetArray([2.0,2.1,2.2], -1), OffsetArray([3.2,3.3,3.4], +1))
401+
5×3 OffsetArray(::Array{Real,2}, 0:4, 1:3) with eltype Real with indices 0:4×1:3:
402+
0 2.0 0
403+
1 2.1 0
404+
2 2.2 3.2
405+
3 0 3.3
406+
0 0 3.4
407+
```
408+
409+
"""
410+
rstack(x::AbstractArray, ys::AbstractArray...; kw...) = rstack((x, ys...); kw...)
411+
rstack(g::Base.Generator; kw...) = rstack(collect(g); kw...)
412+
rstack(f::Function, ABC...; kw...) = rstack(map(f, ABC...); kw...)
413+
rstack(list::AbstractArray{<:AbstractArray}; fill=zero(eltype(first(list)))) = rstack_iter(list; fill=fill)
414+
rstack(list::Tuple{Vararg{<:AbstractArray}}; fill=zero(eltype(first(list)))) = rstack_iter(list; fill=fill)
415+
416+
function rstack_iter(list; fill)
417+
T = mapreduce(eltype, Base.promote_typejoin, list, init=typeof(fill))
418+
# T = mapreduce(eltype, Base.promote_type, list, init=typeof(fill))
419+
N = maximum(ndims, list)
420+
ax = ntuple(N) do d
421+
hi = maximum(x -> last(axes(x,d)), list)
422+
if all(x -> axes(x,d) isa Base.OneTo, list)
423+
Base.OneTo(hi)
424+
else
425+
lo = minimum(x -> first(axes(x,d)), list)
426+
lo:hi
427+
end
428+
end
429+
arr = Array{T}(undef, map(length, ax)..., size(list)...)
430+
fill!(arr, fill)
431+
out = if ax isa Tuple{Vararg{Base.OneTo}}
432+
arr
433+
else
434+
OffsetArray(arr, (ax..., axes(list)...))
435+
end
436+
z = rstack_copyto!(out, list, Val(N))
437+
438+
rewrap_names(z, first(list)) # now I want to separate names & offsets again!
439+
end
440+
441+
function rstack_copyto!(out, list, ::Val{N}) where {N}
442+
for i in tupleindices(list)
443+
item = list[i...]
444+
o = ntuple(_->1, N - ndims(item))
445+
out[CartesianIndices(axes(item)), o..., i...] .= item
446+
447+
# https://github.com/JuliaArrays/OffsetArrays.jl/issues/100
448+
# view(out, axes(item)..., o..., i...) .= item
449+
450+
# for I in CartesianIndices(item)
451+
# out[Tuple(I)..., o..., i...] = item[I]
452+
# end
453+
end
454+
out
455+
end
456+
457+
tupleindices(t::Tuple) = ((i,) for i in 1:length(t))
458+
tupleindices(A::AbstractArray) = (Tuple(I) for I in CartesianIndices(A))
459+
460+
rewrap_names(A, a) = A
461+
function rewrap_names(A, a::NamedDimsArray{L}) where {L}
462+
B = rewrap_names(A, parent(a))
463+
ensure_named(B, (L..., ntuple(_ -> :_, ndims(A) - ndims(a))...))
464+
end
465+
function rstack(s::Symbol, args...)
466+
data = rstack(args...)
467+
name_last = ntuple(d -> d==ndims(data) ? s : :_, ndims(data))
468+
ensure_named(data, name_last)
469+
end
470+
371471
end # module

test/runtests.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using OffsetArrays, NamedDims
1818
@test stack(v34)[1,1,1] == v34[1][1] # trailing dims
1919
@test stack(v34) * ones(4) hcat(v34...) * ones(4) # issue #6
2020
@test stack(v34) * ones(4,2) hcat(v34...) * ones(4,2)
21+
@test axes(stack(v34)) === axes(stack(v34...)) === axes(stack(v34[i] for i in 1:4))
2122

2223
end
2324
@testset "tuples" begin
@@ -132,6 +133,9 @@ end
132133
@test axes(stack(oout)) == (1:3, 11:14)
133134
@test axes(copy(stack(oout))) == (1:3, 11:14)
134135

136+
oboth = OffsetArray(oin, 11:14)
137+
@test axes(stack(oboth)) == (3:5, 11:14)
138+
135139
ogen = (OffsetArray([3,4,5], 3:5) for i in 1:4)
136140
@test axes(stack(ogen)) == (3:5, 1:4)
137141

@@ -178,23 +182,6 @@ end
178182

179183
@test_throws DimensionMismatch push!(stack([rand(2)]), rand(3))
180184

181-
end
182-
@info "loading Zygote"
183-
using Zygote
184-
@testset "zygote" begin
185-
186-
@test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1])
187-
@test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1])
188-
189-
f399(x) = sum(stack(x) * sum(x))
190-
f399c(x) = sum(collect(stack(x)) * sum(x))
191-
@test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
192-
@test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
193-
ftup(x) = sum(stack(x...) * sum(x))
194-
ftupc(x) = sum(collect(stack(x...)) * sum(x))
195-
@test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),)
196-
@test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),)
197-
198185
end
199186
@testset "readme" begin
200187

@@ -218,3 +205,46 @@ end
218205
@test LazyStack.vstack(g234) == reduce(vcat, collect(g234))
219206

220207
end
208+
@testset "ragged" begin
209+
210+
@test rstack([1,2], 1:3) == [1 1; 2 2; 0 3]
211+
@test rstack([[1,2], 1:3], fill=99) == [1 1; 2 2; 99 3]
212+
213+
@test rstack(1:2, OffsetArray([2,3], +1)) == [1 0; 2 2; 0 3]
214+
@test rstack(1:2, OffsetArray([0.1,1], -1)) == OffsetArray([0 0.1; 1 1.0; 2 0],-1,0)
215+
216+
@test dimnames(rstack(:b, NamedDimsArray(1:2, :a), OffsetArray([2,3], +1))) == (:a, :b)
217+
218+
end
219+
@testset "tuple functions" begin
220+
221+
@test LazyStack.ndims([1,2]) == 1
222+
@test LazyStack.ndims((1,2)) == 1
223+
@test LazyStack.ndims((a=1,b=2)) == 1
224+
225+
@test LazyStack.size([1,2]) == (2,)
226+
@test LazyStack.size((1,2)) == (2,)
227+
@test LazyStack.size((a=1,b=2)) == (2,)
228+
229+
@test LazyStack.axes([1,2]) == (1:2,)
230+
@test LazyStack.axes((1,2)) == (1:2,)
231+
@test LazyStack.axes((a=1,b=2)) == (1:2,)
232+
233+
end
234+
@info "loading Zygote"
235+
using Zygote
236+
@testset "zygote" begin
237+
238+
@test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1])
239+
@test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1])
240+
241+
f399(x) = sum(stack(x) * sum(x))
242+
f399c(x) = sum(collect(stack(x)) * sum(x))
243+
@test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
244+
@test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
245+
ftup(x) = sum(stack(x...) * sum(x))
246+
ftupc(x) = sum(collect(stack(x...)) * sum(x))
247+
@test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),)
248+
@test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),)
249+
250+
end

0 commit comments

Comments
 (0)