Skip to content

Commit b2e514e

Browse files
committed
Use cld and fld
Add tests for `chunk` with `size`
1 parent cd96085 commit b2e514e

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

src/batchview.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no
100100
throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`."))
101101
end
102102
E = _batchviewelemtype(data, collate)
103-
count = partial ? ceil(Int, n / batchsize) : floor(Int, n / batchsize)
103+
count = partial ? cld(n, batchsize) : fld(n, batchsize)
104104
BatchView{E,T,typeof(collate)}(data, batchsize, count, partial)
105105
end
106106

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,13 @@ julia> xes[2]
181181
```
182182
"""
183183
chunk(x; size::Int) = collect(Iterators.partition(x, size))
184-
chunk(x, n::Int) = chunk(x; size = ceil(Int, length(x) / n))
184+
chunk(x, n::Int) = chunk(x; size = cld(length(x), n))
185185

186186
function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x))
187187
idxs = _partition_idxs(x, size, dims)
188188
[selectdim(x, dims, i) for i in idxs]
189189
end
190-
chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = ceil(Int, size(x, dims) / n), dims)
190+
chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = cld(size(x, dims), n), dims)
191191

192192
function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x))
193193
# this is the implementation of chunk

test/utils.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,16 @@ end
101101
x = reshape(collect(1:20), (5, 4))
102102
cs = chunk(x, 2)
103103
@test length(cs) == 2
104-
cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10]
105-
cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20]
106-
104+
@test cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10]
105+
@test cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20]
106+
107+
x = permutedims(reshape(collect(1:10), (2, 5)))
108+
cs = chunk(x; size = 2, dims = 1)
109+
@test length(cs) == 3
110+
@test cs[1] == [1 2; 3 4]
111+
@test cs[2] == [5 6; 7 8]
112+
@test cs[3] == [9 10]
113+
107114
# test gradient
108115
test_zygote(chunk, rand(10), 3, check_inferred=false)
109116

@@ -113,7 +120,7 @@ end
113120
x = rand(4, 5)
114121
l = chunk(x, 2)
115122
dl = randn!.(collect.(l))
116-
idxs = MLUtils._partition_idxs(x, ceil(Int, size(x, dims) / n), dims)
123+
idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims)
117124
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)
118125
end
119126

0 commit comments

Comments
 (0)