Skip to content

Commit 2c4f0da

Browse files
committed
fix last tests
1 parent 6732210 commit 2c4f0da

File tree

2 files changed

+35
-21
lines changed

2 files changed

+35
-21
lines changed

src/cat.jl

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ Returned from `cat` on disk arrays.
1515
1616
It is also useful on its own as it can easily concatenate an array of disk arrays.
1717
"""
18-
struct ConcatDiskArray{T,N,P,C,HC} <: AbstractDiskArray{T,N}
18+
struct ConcatDiskArray{T,N,P,C,HC,ID} <: AbstractDiskArray{T,N}
1919
parents::P
2020
startinds::NTuple{N,Vector{Int}}
2121
size::NTuple{N,Int}
2222
chunks::C
2323
haschunks::HC
24+
innerdims::Val{ID}
2425
end
2526

2627
function ConcatDiskArray(arrays::AbstractArray{Union{<:AbstractArray,Missing}})
@@ -30,47 +31,57 @@ function ConcatDiskArray(arrays::AbstractArray{Union{<:AbstractArray,Missing}})
3031
M = ndims(et)
3132
_ConcatDiskArray(arrays, T, Val(N), Val(M))
3233
end
33-
function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray})
34-
T = eltype(eltype(arrays))
35-
N = ndims(arrays)
36-
M = ndims(eltype(arrays))
37-
_ConcatDiskArray(arrays, T, Val(N), Val(M))
38-
end
39-
function ConcatDiskArray(arrays::AbstractArray)
40-
N = ndims(arrays)
41-
M, T = foldl(arrays, init=(-1, Union{})) do (M, T), a
34+
function infer_eltypes(arrays)
35+
foldl(arrays, init=(-1, Union{})) do (M, T), a
4236
if ismissing(a)
4337
(M, promote_type(Missing, T))
4438
else
4539
M == -1 || ndims(a) == M || throw(ArgumentError("All arrays to concatenate must have equal ndims"))
4640
(ndims(a), promote_type(eltype(a), T))
4741
end
4842
end
43+
end
44+
function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray})
45+
N = ndims(arrays)
46+
T = eltype(eltype(arrays))
47+
if !isconcretetype(T)
48+
M,T = infer_eltypes(arrays)
49+
else
50+
M = ndims(eltype(arrays))
51+
end
52+
_ConcatDiskArray(arrays, T, Val(N), Val(M))
53+
end
54+
function ConcatDiskArray(arrays::AbstractArray)
55+
N = ndims(arrays)
56+
M,T = infer_eltypes(arrays)
4957
_ConcatDiskArray(arrays, T, Val(N), Val(M))
5058
end
5159

5260

5361
function _ConcatDiskArray(arrays, T, ::Val{N}, ::Val{M}) where {N,M}
54-
if N > M
55-
newshape = extenddims(size(arrays), ntuple(_ -> 1, N), 1)
62+
if N < M
63+
newshape = extenddims(size(arrays), ntuple(_ -> 1, M), 1)
5664
arrays1 = reshape(arrays, newshape)
57-
D = N
65+
D = M
5866
else
5967
arrays1 = arrays
60-
D = M
68+
D = N
6169
end
62-
_ConcatDiskArray(arrays1::AbstractArray, T, Val(D))
70+
ConcatDiskArray(arrays1::AbstractArray, T, Val(D), Val(M))
6371
end
64-
function _ConcatDiskArray(arrays1::AbstractArray, T, ::Val{D}) where {D}
72+
function ConcatDiskArray(arrays1::AbstractArray, T, ::Val{D},::Val{ID}) where {D,ID}
6573
startinds, sizes = arraysize_and_startinds(arrays1)
6674

6775
chunks = concat_chunksize(arrays1)
6876
hc = Chunked(batchstrategy(chunks))
6977

70-
return ConcatDiskArray{T,D,typeof(arrays1),typeof(chunks),typeof(hc)}(arrays1, startinds, sizes, chunks, hc)
78+
return ConcatDiskArray{T,D,typeof(arrays1),typeof(chunks),typeof(hc),ID}(arrays1, startinds, sizes, chunks, hc,Val(ID))
7179
end
7280

73-
extenddims(a::Tuple{Vararg{Any,N}}, b::Tuple{Vararg{Any,M}}, fillval) where {N,M} = extenddims((a..., fillval), b, fillval)
81+
function extenddims(a::Tuple{Vararg{Any,N}}, b::Tuple{Vararg{Any,M}}, fillval) where {N,M}
82+
length(a) > length(b) && error("Wrong")
83+
extenddims((a..., fillval), b, fillval)
84+
end
7485
extenddims(a::Tuple{Vararg{Any,N}}, _::Tuple{Vararg{Any,N}}, _) where {N} = a
7586

7687
Base.size(a::ConcatDiskArray) = a.size
@@ -134,9 +145,12 @@ function writeblock!(a::ConcatDiskArray, aout, inds::AbstractUnitRange...)
134145
end
135146

136147
# Utils
148+
ninnerdims(a::ConcatDiskArray) = ninnerdims(a.innerdims)
149+
ninnerdims(::Val{ID}) where ID = ID
137150

138151
function _concat_diskarray_block_io(f, a::ConcatDiskArray, inds...)
139152
# Find affected blocks and indices in blocks
153+
ID = ninnerdims(a)
140154
blockinds = map(inds, a.startinds, size(a.parents)) do i, si, s
141155
bi1 = max(searchsortedlast(si, first(i)), 1)
142156
bi2 = min(searchsortedfirst(si, last(i) + 1) - 1, s)
@@ -147,15 +161,14 @@ function _concat_diskarray_block_io(f, a::ConcatDiskArray, inds...)
147161
size_inferred = map(a.startinds, size(a), cI.I) do si, sa, ii
148162
ii == length(si) ? sa - si[ii] + 1 : si[ii+1] - si[ii]
149163
end
150-
mysize = extenddims(size_inferred, cI.I, 1)
151-
array_range = map(cI.I, a.startinds, mysize, inds) do ii, si, ms, indstoread
164+
array_range = map(cI.I, a.startinds, size_inferred, inds) do ii, si, ms, indstoread
152165
max(first(indstoread) - si[ii] + 1, 1):min(last(indstoread) - si[ii] + 1, ms)
153166
end
154167
outer_range = map(cI.I, a.startinds, array_range, inds) do ii, si, ar, indstoread
155168
(first(ar)+si[ii]-first(indstoread)):(last(ar)+si[ii]-first(indstoread))
156169
end
157170
#Shorten array range to shape of actual array
158-
array_range = map((i, j) -> j, size_inferred, array_range)
171+
array_range = ntuple(j -> array_range[j], ID)
159172
outer_range = fix_outerrangeshape(outer_range, array_range)
160173
if ismissing(myar)
161174
f(outer_range, array_range, missing)

src/diskarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ function readblock!(a::AbstractArray, aout, r...)
2929
@warn "Using fallback readblock! for array $(typeof(a)). This should not happen but there should be a custom implementation."
3030
end
3131
aout .= view(a, CartesianIndices(r))
32+
nothing
3233
end
3334

3435
"""

0 commit comments

Comments
 (0)