Skip to content

Commit 5099132

Browse files
Pietro Vertechitkf
andauthored
use axes correctly in collection (#110)
* Support collecting to offset arrays * use axes correctly in collection * update tests Co-authored-by: Takafumi Arakaki <[email protected]>
1 parent b8e5914 commit 5099132

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

src/collect.jl

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
default_array(::Type{S}, d) where {S} = Array{S}(undef, d)
1+
default_array(::Type{S}, d::NTuple{N, Any}) where {S, N} = similar(Array{S, N}, d)
22

33
struct StructArrayInitializer{F, G}
44
unwrap::F
@@ -21,14 +21,10 @@ ArrayInitializer(unwrap = t->false) = ArrayInitializer(unwrap, default_array)
2121

2222
(s::ArrayInitializer)(S, d) = s.unwrap(S) ? buildfromschema(typ -> s(typ, d), S) : s.default_array(S, d)
2323

24-
_reshape(v, itr) = _reshape(v, itr, Base.IteratorSize(itr))
25-
_reshape(v, itr, ::Base.HasShape) = reshapestructarray(v, axes(itr))
26-
_reshape(v, itr, ::Union{Base.HasLength, Base.SizeUnknown}) = v
27-
28-
# temporary workaround before it gets easier to support reshape with offset axis
29-
reshapestructarray(v::AbstractArray, d) = reshape(v, d)
30-
reshapestructarray(v::StructArray{T}, d) where {T} =
31-
StructArray{T}(map(x -> reshapestructarray(x, d), fieldarrays(v)))
24+
_axes(itr) = _axes(itr, Base.IteratorSize(itr))
25+
_axes(itr, ::Base.SizeUnknown) = nothing
26+
_axes(itr, ::Base.HasLength) = (Base.OneTo(length(itr)),)
27+
_axes(itr, ::Base.HasShape) = axes(itr)
3228

3329
"""
3430
`collect_structarray(itr; initializer = default_initializer)`
@@ -39,31 +35,31 @@ and size `d`. By default `initializer` returns a `StructArray` of `Array` but cu
3935
may be used.
4036
"""
4137
function collect_structarray(itr; initializer = default_initializer)
42-
len = Base.IteratorSize(itr) === Base.SizeUnknown() ? 1 : length(itr)
38+
ax = _axes(itr)
4339
elem = iterate(itr)
44-
_collect_structarray(itr, elem, len; initializer = initializer)
40+
_collect_structarray(itr, elem, ax; initializer = initializer)
4541
end
4642

47-
function _collect_structarray(itr::T, ::Nothing, len; initializer = default_initializer) where {T}
43+
function _collect_structarray(itr::T, ::Nothing, ax; initializer = default_initializer) where {T}
4844
S = Core.Compiler.return_type(first, Tuple{T})
49-
res = initializer(S, (0,))
50-
_reshape(res, itr)
45+
return initializer(S, something(ax, (Base.OneTo(0),)))
5146
end
5247

53-
function _collect_structarray(itr, elem, len; initializer = default_initializer)
48+
function _collect_structarray(itr, elem, ax; initializer = default_initializer)
5449
el, st = elem
5550
S = typeof(el)
56-
dest = initializer(S, (len,))
57-
@inbounds dest[1] = el
58-
return _collect_structarray!(dest, itr, st, Base.IteratorSize(itr))
51+
dest = initializer(S, something(ax, (Base.OneTo(1),)))
52+
offs = first(LinearIndices(dest))
53+
@inbounds dest[offs] = el
54+
return _collect_structarray!(dest, itr, st, ax)
5955
end
6056

61-
function _collect_structarray!(dest, itr, st, ::Union{Base.HasShape, Base.HasLength})
62-
v = collect_to_structarray!(dest, itr, 2, st)
63-
return _reshape(v, itr)
57+
function _collect_structarray!(dest, itr, st, ax)
58+
offs = first(LinearIndices(dest)) + 1
59+
return collect_to_structarray!(dest, itr, offs, st)
6460
end
6561

66-
_collect_structarray!(dest, itr, st, ::Base.SizeUnknown) =
62+
_collect_structarray!(dest, itr, st, ::Nothing) =
6763
grow_to_structarray!(dest, itr, iterate(itr, st))
6864

6965
function collect_to_structarray!(dest::AbstractArray, itr, offs, st)
@@ -122,7 +118,7 @@ _widenstructarray(dest::AbstractArray, i, ::Type{T}) where {T} = _widenarray(des
122118
_widenarray(dest::AbstractArray{T}, i, ::Type{T}) where {T} = dest
123119
function _widenarray(dest::AbstractArray, i, ::Type{T}) where T
124120
new = similar(dest, T, length(dest))
125-
copyto!(new, 1, dest, 1, i-1)
121+
copyto!(new, firstindex(new), dest, firstindex(dest), i-1)
126122
new
127123
end
128124

test/runtests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,23 @@ end
527527
@test sa isa StructArray
528528
@test axes(sa) == (-2:7,)
529529
@test sa.a == fill(1, -2:7)
530+
531+
zero_origin(T, d) = OffsetArray{T}(undef, map(r -> r .- 1, d))
532+
sa = collect_structarray(
533+
[(a = 1,), (a = 2,), (a = 3,)],
534+
initializer = StructArrays.StructArrayInitializer(t -> false, zero_origin),
535+
)
536+
@test sa isa StructArray
537+
@test collect(sa.a) == 1:3
538+
@test sa.a isa OffsetArray
539+
540+
sa = collect_structarray(
541+
(x for x in [(a = 1,), (a = 2,), (a = 3,)] if true),
542+
initializer = StructArrays.StructArrayInitializer(t -> false, zero_origin),
543+
)
544+
@test sa isa StructArray
545+
@test collect(sa.a) == 1:3
546+
@test sa.a isa OffsetArray
530547
end
531548

532549
@testset "hasfields" begin

0 commit comments

Comments
 (0)