Skip to content

Commit a88bbf3

Browse files
committed
make cat almost type-stable
1 parent e17f243 commit a88bbf3

File tree

1 file changed

+30
-33
lines changed

1 file changed

+30
-33
lines changed

src/cat.jl

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,46 +21,17 @@ end
2121
function ConcatDiskArray(arrays::AbstractArray{<:AbstractArray{<:Any,N},M}) where {N,M}
2222
T = mapreduce(eltype, promote_type, init=eltype(first(arrays)), arrays)
2323

24-
function othersize(x, id)
25-
return (x[1:(id-1)]..., x[(id+1):end]...)
26-
end
2724
if N > M
28-
newshape = (size(arrays)..., ntuple(_ -> 1, N - M)...)
25+
newshape = extenddims(size(arrays), size(first(arrays)))
26+
@show newshape
2927
arrays1 = reshape(arrays, newshape)
3028
D = N
31-
elseif N < M
32-
arrays1 = map(arrays) do a
33-
newshape = (size(a)..., ntuple(_ -> 1, M - N)...)
34-
reshape(a, newshape)
35-
end
36-
D = M
3729
else
3830
arrays1 = arrays
3931
D = M
4032
end
41-
arraysizes = map(size, arrays1)
42-
si = ntuple(D) do id
43-
a = reduce(arraysizes; dims=id, init=ntuple(zero, D)) do i, j
44-
if all(iszero, i)
45-
j
46-
elseif othersize(i, id) == othersize(j, id)
47-
j
48-
else
49-
error("Dimension sizes don't match")
50-
end
51-
end
52-
I = ntuple(D) do i
53-
i == id ? Colon() : 1
54-
end
55-
ari = map(i -> i[id], arraysizes[I...])
56-
sl = sum(ari)
57-
r = cumsum(ari)
58-
pop!(pushfirst!(r, 0))
59-
r .+ 1, sl
60-
end
61-
62-
startinds = map(first, si)
63-
sizes = map(last, si)
33+
startinds, sizes = arraysize_and_startinds(arrays1)
34+
@show startinds, sizes
6435

6536
chunks = concat_chunksize(D, arrays1)
6637
hc = Chunked(batchstrategy(chunks))
@@ -73,9 +44,35 @@ function ConcatDiskArray(arrays::AbstractArray)
7344
error("Arrays don't have the same dimensions")
7445
return error("Should not be reached")
7546
end
47+
extenddims(a::NTuple{N,Int},b::NTuple{M,Int}) where {N,M} = extenddims((a...,1), b)
48+
extenddims(a::NTuple{N,Int},b::NTuple{N,Int}) where {N} = a
7649

7750
Base.size(a::ConcatDiskArray) = a.size
7851

52+
function arraysize_and_startinds(arrays1)
53+
sizes = map(i->zeros(Int,i),size(arrays1))
54+
for i in CartesianIndices(arrays1)
55+
ai = arrays1[i]
56+
sizecur = size(ai)
57+
foreach(sizecur,i.I,sizes) do si, ind, sizeall
58+
if sizeall[ind] == 0
59+
#init the size
60+
sizeall[ind] = si
61+
elseif sizeall[ind] != si
62+
throw(ArgumentError("Array sizes don't form a grid"))
63+
end
64+
end
65+
end
66+
r = map(sizes) do sizeall
67+
pushfirst!(sizeall, 1)
68+
for i in 2:length(sizeall)
69+
sizeall[i] = sizeall[i-1]+sizeall[i]
70+
end
71+
pop!(sizeall)-1,sizeall
72+
end
73+
map(last, r), map(first, r)
74+
end
75+
7976
# DiskArrays interface
8077

8178
eachchunk(a::ConcatDiskArray) = a.chunks

0 commit comments

Comments
 (0)