Skip to content

Commit 387b371

Browse files
authored
Merge pull request #112 from theabhirath/chunk-size
Extend `chunk` to take `size` as an argument
2 parents b29bf41 + 1128724 commit 387b371

File tree

4 files changed

+51
-23
lines changed

4 files changed

+51
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.2.9"
4+
version = "0.2.10"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/batchview.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no
100100
throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`."))
101101
end
102102
E = _batchviewelemtype(data, collate)
103-
count = partial ? ceil(Int, n / batchsize) : floor(Int, n / batchsize)
103+
count = partial ? cld(n, batchsize) : fld(n, batchsize)
104104
BatchView{E,T,typeof(collate)}(data, batchsize, count, partial)
105105
end
106106

src/utils.jl

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,10 @@ unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)
122122

123123
"""
124124
chunk(x, n; [dims])
125+
chunk(x; [size, dims])
125126
126-
Split `x` into `n` parts. The parts contain the same number of elements
127-
except possibly for the last one that can be smaller.
127+
Split `x` into `n` parts or alternatively, into equal chunks of size `size`. The parts contain
128+
the same number of elements except possibly for the last one that can be smaller.
128129
129130
If `x` is an array, `dims` can be used to specify along which dimension to
130131
split (defaults to the last dimension).
@@ -138,6 +139,14 @@ julia> chunk(1:10, 3)
138139
5:8
139140
9:10
140141
142+
julia> chunk(1:10; size = 2)
143+
5-element Vector{UnitRange{Int64}}:
144+
1:2
145+
3:4
146+
5:6
147+
7:8
148+
9:10
149+
141150
julia> x = reshape(collect(1:20), (5, 4))
142151
5×4 Matrix{Int64}:
143152
1 6 11 16
@@ -156,30 +165,42 @@ julia> xs[1]
156165
1 6 11 16
157166
2 7 12 17
158167
3 8 13 18
168+
169+
julia> xes = chunk(x; size = 2, dims = 2)
170+
2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
171+
[1 6; 2 7; … ; 4 9; 5 10]
172+
[11 16; 12 17; … ; 14 19; 15 20]
173+
174+
julia> xes[2]
175+
5×2 view(::Matrix{Int64}, :, 3:4) with eltype Int64:
176+
11 16
177+
12 17
178+
13 18
179+
14 19
180+
15 20
159181
```
160182
"""
161-
chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n)))
183+
chunk(x; size::Int) = collect(Iterators.partition(x, size))
184+
chunk(x, n::Int) = chunk(x; size = cld(length(x), n))
162185

163-
function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x))
164-
idxs = _partition_idxs(x, n, dims)
186+
function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x))
187+
idxs = _partition_idxs(x, size, dims)
165188
[selectdim(x, dims, i) for i in idxs]
166189
end
190+
chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = cld(size(x, dims), n), dims)
167191

168-
function _partition_idxs(x, n, dims)
169-
bs = ceil(Int, size(x, dims) / n)
170-
Iterators.partition(axes(x, dims), bs)
171-
end
172-
173-
function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x))
192+
function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x))
174193
# this is the implementation of chunk
175-
idxs = _partition_idxs(x, n, dims)
194+
idxs = _partition_idxs(x, size, dims)
176195
y = [selectdim(x, dims, i) for i in idxs]
177196
valdims = Val(dims)
178-
chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent())
179-
197+
chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims))
198+
180199
return y, chunk_pullback
181200
end
182201

202+
_partition_idxs(x, size, dims) = Iterators.partition(axes(x, dims), size)
203+
183204
# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
184205
function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim}
185206
i1 = findfirst(dy -> !(dy isa AbstractZero), dys)

test/utils.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,27 @@ end
101101
x = reshape(collect(1:20), (5, 4))
102102
cs = chunk(x, 2)
103103
@test length(cs) == 2
104-
cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10]
105-
cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20]
106-
104+
@test cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10]
105+
@test cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20]
106+
107+
x = permutedims(reshape(collect(1:10), (2, 5)))
108+
cs = chunk(x; size = 2, dims = 1)
109+
@test length(cs) == 3
110+
@test cs[1] == [1 2; 3 4]
111+
@test cs[2] == [5 6; 7 8]
112+
@test cs[3] == [9 10]
113+
107114
# test gradient
108115
test_zygote(chunk, rand(10), 3, check_inferred=false)
109116

110117
# indirect test of second order derivates
111118
n = 2
112119
dims = 2
113120
x = rand(4, 5)
114-
y = chunk(x, 2)
115-
dy = randn!.(collect.(y))
116-
idxs = MLUtils._partition_idxs(x, n, dims)
117-
test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false)
121+
l = chunk(x, 2)
122+
dl = randn!.(collect.(l))
123+
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
124+
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)
118125
end
119126

120127
@testset "group_counts" begin

0 commit comments

Comments
 (0)