Skip to content

Commit 2d5565b

Browse files
authored
Allow constructing ConcatDiskArray from mixed type arrays (#210)
* Improve concat diskarray
* Test mixed type concatenation
* Make cat of unchunked arrays chunked
1 parent 664ec6b commit 2d5565b

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/cat.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
55
Joins multiple AbstractArrays or AbstractDiskArrays in a lazy concatenation.
66
"""
7-
struct ConcatDiskArray{T,N,P} <: AbstractDiskArray{T,N}
7+
struct ConcatDiskArray{T,N,P,C,HC} <: AbstractDiskArray{T,N}
88
parents::P
99
startinds::NTuple{N,Vector{Int}}
1010
size::NTuple{N,Int}
11+
chunks::C
12+
haschunks::HC
1113
end
12-
function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray{T,N},M}) where {T,N,M}
14+
function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray{<:Any,N},M}) where {N,M}
15+
T = mapreduce(eltype,promote_type, init = eltype(first(arrays)),arrays)
16+
1317
function othersize(x, id)
1418
return (x[1:(id - 1)]..., x[(id + 1):end]...)
1519
end
@@ -51,12 +55,13 @@ function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray{T,N},M}) where {T
5155
startinds = map(first, si)
5256
sizes = map(last, si)
5357

54-
return ConcatDiskArray{T,D,typeof(arrays1)}(arrays1, startinds, sizes)
58+
chunks = concat_chunksize(D, arrays1)
59+
hc = Chunked(batchstrategy(chunks))
60+
61+
return ConcatDiskArray{T,D,typeof(arrays1),typeof(chunks),typeof(hc)}(arrays1, startinds, sizes, chunks, hc)
5562
end
5663
function ConcatDiskArray(arrays::AbstractArray)
5764
# Validate array eltype and dimensionality
58-
all(a -> eltype(a) == eltype(first(arrays)), arrays) ||
59-
error("Arrays don't have the same element type")
6065
all(a -> ndims(a) == ndims(first(arrays)), arrays) ||
6166
error("Arrays don't have the same dimensions")
6267
return error("Should not be reached")
@@ -98,11 +103,10 @@ function _concat_diskarray_block_io(f, a::ConcatDiskArray, inds...)
98103
end
99104
end
100105

101-
haschunks(::ConcatDiskArray) = Chunked()
106+
haschunks(c::ConcatDiskArray) = c.haschunks
102107

103-
function eachchunk(aconc::ConcatDiskArray{T,N}) where {T,N}
104-
s = size(aconc)
105-
oldchunks = map(eachchunk, aconc.parents)
108+
function concat_chunksize(N, parents)
109+
oldchunks = map(eachchunk, parents)
106110
newchunks = ntuple(N) do i
107111
sliceinds = Base.setindex(ntuple(_ -> 1, N), :, i)
108112
v = map(c -> c.chunks[i], oldchunks[sliceinds...])
@@ -113,6 +117,10 @@ function eachchunk(aconc::ConcatDiskArray{T,N}) where {T,N}
113117
return GridChunks(newchunks...)
114118
end
115119

120+
function eachchunk(aconc::ConcatDiskArray{T,N}) where {T,N}
121+
aconc.chunks
122+
end
123+
116124
function mergechunks(a::RegularChunks, b::RegularChunks)
117125
if a.s == 0 || (a.cs == b.cs && length(last(a)) == a.cs)
118126
RegularChunks(a.cs, a.offset, a.s + b.s)

test/runtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,15 @@ end
428428
@test sum(ca) == sum(0:23)
429429
end
430430

431+
@testset "Concatenation of unchunked arrays" begin
432+
a = UnchunkedDiskArray(rand(10,20))
433+
b = UnchunkedDiskArray(rand(15,20))
434+
c = UnchunkedDiskArray(rand(18,20))
435+
d = cat(a, b, c; dims=1)
436+
@test DiskArrays.eachchunk(d) == [(1:10, 1:20); (11:25, 1:20); (26:43, 1:20);;]
437+
@test DiskArrays.haschunks(d) isa DiskArrays.Chunked
438+
end
439+
431440
@testset "cat mixed chunk size" begin
432441
a = AccessCountDiskArray(collect(1:10); chunksize=(3,))
433442
b = AccessCountDiskArray(collect(1:9); chunksize=(4,))
@@ -450,6 +459,19 @@ end
450459
@test d == 1:26
451460
@test c == 20:26
452461
end
462+
463+
@testset "ConcatDiskArray works for mixed element types" begin
464+
da1 = AccessCountDiskArray(collect(Float64,reshape(1:24, 4, 6, 1)))
465+
da2 = AccessCountDiskArray(collect(Float32,reshape(1:24, 4, 6, 1)))
466+
@test eltype(da1) <: Float64
467+
@test eltype(da2) <: Float32
468+
c = DiskArrays.ConcatDiskArray([da1, da2])
469+
470+
@test eltype(c) <: Float64
471+
slic = c[:,1,1]
472+
@test slic isa Vector{Float64}
473+
@test slic == Float64[1, 2, 3, 4, 1, 2, 3, 4]
474+
end
453475
end
454476

455477
@testset "Broadcast with length 1 and 0 final dim" begin

0 commit comments

Comments
 (0)