Skip to content

Commit 2d9c6b2

Browse files
author
Michael Abbott
committed
rstack_iter, names, CartesianIndices
1 parent fce4cf6 commit 2d9c6b2

File tree

2 files changed

+49
-24
lines changed

2 files changed

+49
-24
lines changed

src/LazyStack.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,11 @@ julia> rstack(1:3, OffsetArray([2.0,2.1,2.2], -1), OffsetArray([3.2,3.3,3.4], +1
405405
"""
406406
rstack(x::AbstractArray, ys::AbstractArray...; kw...) = rstack((x, ys...); kw...)
407407
rstack(g::Base.Generator; kw...) = rstack(collect(g); kw...)
408+
rstack(f::Function, ABC...; kw...) = rstack(map(f, ABC...); kw...)
409+
rstack(list::AbstractArray{<:AbstractArray}; fill=zero(eltype(first(list)))) = rstack_iter(list; fill=fill)
410+
rstack(list::Tuple{Vararg{<:AbstractArray}}; fill=zero(eltype(first(list)))) = rstack_iter(list; fill=fill)
408411

409-
function rstack(list::Union{AbstractArray{<:AbstractArray}, Tuple{Vararg{<:AbstractArray}}}; fill=zero(eltype(first(list))))
412+
function rstack_iter(list; fill)
410413
T = mapreduce(eltype, Base.promote_typejoin, list, init=typeof(fill))
411414
# T = mapreduce(eltype, Base.promote_type, list, init=typeof(fill))
412415
N = maximum(ndims, list)
@@ -426,19 +429,39 @@ function rstack(list::Union{AbstractArray{<:AbstractArray}, Tuple{Vararg{<:Abstr
426429
else
427430
OffsetArray(arr, (ax..., axes(list)...))
428431
end
432+
z = rstack_copyto!(out, list, Val(N))
433+
434+
rewrap_names(z, first(list)) # now I want to separate names & offsets again!
435+
end
436+
437+
function rstack_copyto!(out, list, ::Val{N}) where {N}
429438
for i in tupleindices(list)
430439
item = list[i...]
431440
o = ntuple(_->1, N - ndims(item))
432-
# view(out, axes(item)..., i...) .= item
433-
for I in CartesianIndices(item)
434-
out[Tuple(I)..., o..., i...] = item[I]
435-
end
441+
out[CartesianIndices(axes(item)), o..., i...] .= item
442+
443+
# https://github.com/JuliaArrays/OffsetArrays.jl/issues/100
444+
# view(out, axes(item)..., o..., i...) .= item
445+
446+
# for I in CartesianIndices(item)
447+
# out[Tuple(I)..., o..., i...] = item[I]
448+
# end
436449
end
437450
out
438451
end
439452

440453
tupleindices(t::Tuple) = ((i,) for i in 1:length(t))
441454
tupleindices(A::AbstractArray) = (Tuple(I) for I in CartesianIndices(A))
442455

456+
rewrap_names(A, a) = A
457+
function rewrap_names(A, a::NamedDimsArray{L}) where {L}
458+
B = rewrap_names(A, parent(a))
459+
ensure_named(B, (L..., ntuple(_ -> :_, ndims(A) - ndims(a))...))
460+
end
461+
function rstack(s::Symbol, args...)
462+
data = rstack(args...)
463+
name_last = ntuple(d -> d==ndims(data) ? s : :_, ndims(data))
464+
ensure_named(data, name_last)
465+
end
443466

444467
end # module

test/runtests.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -178,23 +178,6 @@ end
178178

179179
@test_throws DimensionMismatch push!(stack([rand(2)]), rand(3))
180180

181-
end
182-
@info "loading Zygote"
183-
using Zygote
184-
@testset "zygote" begin
185-
186-
@test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1])
187-
@test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1])
188-
189-
f399(x) = sum(stack(x) * sum(x))
190-
f399c(x) = sum(collect(stack(x)) * sum(x))
191-
@test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
192-
@test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
193-
ftup(x) = sum(stack(x...) * sum(x))
194-
ftupc(x) = sum(collect(stack(x...)) * sum(x))
195-
@test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),)
196-
@test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),)
197-
198181
end
199182
@testset "readme" begin
200183

@@ -223,7 +206,26 @@ end
223206
@test rstack([1,2], 1:3) == [1 1; 2 2; 0 3]
224207
@test rstack([[1,2], 1:3], fill=99) == [1 1; 2 2; 99 3]
225208

226-
@test rstack(1:2, OffsetArray([2,3], 2:3)) == [1 0; 2 2; 0 3]
227-
@test rstack(1:2, OffsetArray([0.1,1], 0:1)) == OffsetArray([0 0.1; 1 1.0; 2 0],-1,0)
209+
@test rstack(1:2, OffsetArray([2,3], +1)) == [1 0; 2 2; 0 3]
210+
@test rstack(1:2, OffsetArray([0.1,1], -1)) == OffsetArray([0 0.1; 1 1.0; 2 0],-1,0)
211+
212+
@test dimnames(rstack(:b, NamedDimsArray(1:2, :a), OffsetArray([2,3], +1))) == (:a, :b)
213+
214+
end
215+
@info "loading Zygote"
216+
using Zygote
217+
@testset "zygote" begin
218+
219+
@test Zygote.gradient((x,y) -> sum(stack(x,y)), ones(2), ones(2)) == ([1,1], [1,1])
220+
@test Zygote.gradient((x,y) -> sum(stack([x,y])), ones(2), ones(2)) == ([1,1], [1,1])
221+
222+
f399(x) = sum(stack(x) * sum(x))
223+
f399c(x) = sum(collect(stack(x)) * sum(x))
224+
@test Zygote.gradient(f399, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
225+
@test Zygote.gradient(f399c, [ones(2), ones(2)]) == ([[4,4], [4,4]],)
226+
ftup(x) = sum(stack(x...) * sum(x))
227+
ftupc(x) = sum(collect(stack(x...)) * sum(x))
228+
@test Zygote.gradient(ftup, (ones(2), ones(2))) == (([4,4], [4,4]),)
229+
@test Zygote.gradient(ftupc, (ones(2), ones(2))) == (([4,4], [4,4]),)
228230

229231
end

0 commit comments

Comments
 (0)