Skip to content

Commit efcd9bb

Browse files
authored
fast path for common_chunks where only one disk array (#224)
* fast path for common_chunks where only one disk array * bugfix tt * fix if block nesting * add chunk equality comparisons * fix == for small RegularChunks
1 parent 74fdfa4 commit efcd9bb

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/broadcast.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,24 @@ function eachchunk(a::BroadcastDiskArray)
5252
end
5353
function common_chunks(s, args...)
5454
N = length(s)
55-
chunkedars = filter(i -> haschunks(i) === Chunked(), collect(args))
56-
all(ar -> isa(eachchunk(ar), GridChunks), chunkedars) ||
55+
chunkedarrays = reduce(args; init=()) do acc, x
56+
haschunks(x) === Chunked() ? (acc..., x) : acc
57+
end
58+
all(ar -> isa(eachchunk(ar), GridChunks), chunkedarrays) ||
5759
error("Currently only chunks of type GridChunks can be merged by broadcast")
58-
if isempty(chunkedars)
60+
if isempty(chunkedarrays)
5961
totalsize = sum(sizeof ∘ eltype, args)
6062
return estimate_chunksize(s, totalsize)
63+
elseif length(chunkedarrays) == 1
64+
return eachchunk(only(chunkedarrays))
6165
else
62-
allcs = eachchunk.(chunkedars)
66+
allchunks = collect(map(eachchunk, chunkedarrays))
6367
tt = ntuple(N) do n
64-
csnow = filter(allcs) do cs
65-
ndims(cs) >= n && first(first(cs.chunks[n])) < last(last(cs.chunks[n]))
68+
csnow = filter(allchunks) do cs
69+
ndims(cs) >= n && first(first(cs.chunks[n])) < last(last(cs.chunks[n]))
6670
end
6771
isempty(csnow) && return RegularChunks(1, 0, s[n])
72+
6873
cs = first(csnow).chunks[n]
6974
if all(s -> s.chunks[n] == cs, csnow)
7075
return cs

src/chunks.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,28 @@ function Base.getindex(r::RegularChunks, i::Int)
3838
end
3939
# Length of the chunk vector, computed from the fields `s`, `offset` and `cs`.
# The second (dimension) argument is accepted but ignored.
function Base.size(r::RegularChunks, _)
    n = r.s + r.offset - 1
    return div(n, r.cs) + 1
end
4040
# One-element size tuple, delegating to the two-argument `size` method.
function Base.size(r::RegularChunks)
    nchunks = size(r, 1)
    return (nchunks,)
end
41+
# Equality for two RegularChunks.
#
# Both the axis size `s` and the number of chunks have to agree. Beyond that
# the comparison depends on how many chunks there are, because short chunk
# vectors can be identical even when their `cs` and `offset` fields differ.
function Base.:(==)(r1::RegularChunks, r2::RegularChunks)
    # Differing axis sizes rule out equality right away.
    if r1.s != r2.s
        return false
    end

    # So does a differing number of chunks.
    n = length(r1)
    if n != length(r2)
        return false
    end

    # Nothing left to compare when there are no chunks.
    if n < 1
        return true
    end

    # With one or two chunks, different chunk sizes and offsets can still
    # yield the same ranges, so the ranges themselves are compared.
    if n == 1
        return first(r1) == first(r2)
    end
    if n == 2
        return first(r1) == first(r2) && last(r1) == last(r2)
    end

    # For longer RegularChunks the chunk size and offset must match, so they
    # are compared directly instead of iterating over every range.
    return r1.cs == r2.cs && r1.offset == r2.offset
end
4163

4264
# DiskArrays interface
4365

@@ -135,6 +157,9 @@ function Base.getindex(r::IrregularChunks, i::Int)
135157
return (r.offsets[i] + 1):r.offsets[i + 1]
136158
end
137159
# One-element size tuple: one less than the number of stored offsets.
function Base.size(r::IrregularChunks)
    nchunks = length(r.offsets) - 1
    return (nchunks,)
end
160+
# Equality for two IrregularChunks: true for the very same object, otherwise
# decided by comparing the `offsets` fields.
function Base.:(==)(r1::IrregularChunks, r2::IrregularChunks)
    # Identity check first, which skips the element-wise comparison.
    r1 === r2 && return true
    return r1.offsets == r2.offsets
end
162+
138163
function subsetchunks(r::IrregularChunks, subs::UnitRange)
139164
c1 = findchunk(r, first(subs))
140165
c2 = findchunk(r, last(subs))

0 commit comments

Comments
 (0)