@@ -158,28 +158,27 @@ julia> xs[1]
158
158
3 8 13 18
159
159
```
160
160
"""
161
- chunk (x, n:: Int ) = collect (Iterators. partition (x, ceil (Int, length (x) / n)))
161
+ chunk (x; size:: Int ) = collect (Iterators. partition (x, size))
162
+ chunk (x, n:: Int ) = chunk (x; size = ceil (Int, length (x) / n))
162
163
163
- function chunk (x:: AbstractArray , n :: Int ; dims:: Int = ndims (x))
164
- idxs = _partition_idxs (x, n , dims)
164
+ function chunk (x:: AbstractArray ; size :: Int , dims:: Int = ndims (x))
165
+ idxs = _partition_idxs (x, size , dims)
165
166
[selectdim (x, dims, i) for i in idxs]
166
167
end
168
+ chunk (x:: AbstractArray , n:: Int ; dims:: Int = ndims (x)) = chunk (x; size = ceil (Int, size (x, dims) / n), dims)
167
169
168
- function _partition_idxs (x, n, dims)
169
- bs = ceil (Int, size (x, dims) / n)
170
- Iterators. partition (axes (x, dims), bs)
171
- end
172
-
173
- function rrule (:: typeof (chunk), x:: AbstractArray , n:: Int ; dims:: Int = ndims (x))
170
+ function rrule (:: typeof (chunk), x:: AbstractArray ; size:: Int , dims:: Int = ndims (x))
174
171
# this is the implementation of chunk
175
- idxs = _partition_idxs (x, n , dims)
172
+ idxs = _partition_idxs (x, size , dims)
176
173
y = [selectdim (x, dims, i) for i in idxs]
177
174
valdims = Val (dims)
178
- chunk_pullback (dy) = (NoTangent (), ∇chunk (unthunk (dy), x, idxs, valdims), NoTangent () )
179
-
175
+ chunk_pullback (dy) = (NoTangent (), ∇chunk (unthunk (dy), x, idxs, valdims))
176
+
180
177
return y, chunk_pullback
181
178
end
182
179
180
+ _partition_idxs (x, size, dims) = Iterators. partition (axes (x, dims), size)
181
+
183
182
# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
184
183
function ∇chunk (dys, x:: AbstractArray , idxs, vd:: Val{dim} ) where {dim}
185
184
i1 = findfirst (dy -> ! (dy isa AbstractZero), dys)
0 commit comments