Skip to content

Commit fffe72e

Browse files
committed
Add method for chunk with size of chunks as kwarg
1 parent b29bf41 commit fffe72e

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

src/utils.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,28 +158,27 @@ julia> xs[1]
158158
3 8 13 18
159159
```
160160
"""
161-
chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n)))
161+
chunk(x; size::Int) = collect(Iterators.partition(x, size))
162+
chunk(x, n::Int) = chunk(x; size = ceil(Int, length(x) / n))
162163

163-
function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x))
164-
idxs = _partition_idxs(x, n, dims)
164+
function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x))
165+
idxs = _partition_idxs(x, size, dims)
165166
[selectdim(x, dims, i) for i in idxs]
166167
end
168+
chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = ceil(Int, size(x, dims) / n), dims)
167169

168-
function _partition_idxs(x, n, dims)
169-
bs = ceil(Int, size(x, dims) / n)
170-
Iterators.partition(axes(x, dims), bs)
171-
end
172-
173-
function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x))
170+
function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x))
174171
# this is the implementation of chunk
175-
idxs = _partition_idxs(x, n, dims)
172+
idxs = _partition_idxs(x, size, dims)
176173
y = [selectdim(x, dims, i) for i in idxs]
177174
valdims = Val(dims)
178-
chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent())
179-
175+
chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims))
176+
180177
return y, chunk_pullback
181178
end
182179

180+
_partition_idxs(x, size, dims) = Iterators.partition(axes(x, dims), size)
181+
183182
# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
184183
function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim}
185184
i1 = findfirst(dy -> !(dy isa AbstractZero), dys)

test/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ end
111111
n = 2
112112
dims = 2
113113
x = rand(4, 5)
114-
y = chunk(x, 2)
115-
dy = randn!.(collect.(y))
116-
idxs = MLUtils._partition_idxs(x, n, dims)
117-
test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false)
114+
l = chunk(x, 2)
115+
dl = randn!.(collect.(l))
116+
idxs = MLUtils._partition_idxs(x, ceil(Int, size(x, dims) / n), dims)
117+
test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false)
118118
end
119119

120120
@testset "group_counts" begin

0 commit comments

Comments
 (0)