@@ -122,9 +122,10 @@ unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)
122
122
123
123
"""
124
124
chunk(x, n; [dims])
125
+ chunk(x; [size, dims])
125
126
126
- Split `x` into `n` parts. The parts contain the same number of elements
127
- except possibly for the last one that can be smaller.
127
+ Split `x` into `n` parts or alternatively, into equal chunks of size `size`. The parts contain
128
+ the same number of elements except possibly for the last one that can be smaller.
128
129
129
130
If `x` is an array, `dims` can be used to specify along which dimension to
130
131
split (defaults to the last dimension).
@@ -138,6 +139,14 @@ julia> chunk(1:10, 3)
138
139
5:8
139
140
9:10
140
141
142
+ julia> chunk(1:10; size = 2)
143
+ 5-element Vector{UnitRange{Int64}}:
144
+ 1:2
145
+ 3:4
146
+ 5:6
147
+ 7:8
148
+ 9:10
149
+
141
150
julia> x = reshape(collect(1:20), (5, 4))
142
151
5×4 Matrix{Int64}:
143
152
1 6 11 16
@@ -156,30 +165,42 @@ julia> xs[1]
156
165
1 6 11 16
157
166
2 7 12 17
158
167
3 8 13 18
168
+
169
+ julia> xes = chunk(x; size = 2, dims = 2)
170
+ 2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}:
171
+ [1 6; 2 7; … ; 4 9; 5 10]
172
+ [11 16; 12 17; … ; 14 19; 15 20]
173
+
174
+ julia> xes[2]
175
+ 5×2 view(::Matrix{Int64}, :, 3:4) with eltype Int64:
176
+ 11 16
177
+ 12 17
178
+ 13 18
179
+ 14 19
180
+ 15 20
159
181
```
160
182
"""
161
- chunk (x, n:: Int ) = collect (Iterators. partition (x, ceil (Int, length (x) / n)))
183
+ chunk (x; size:: Int ) = collect (Iterators. partition (x, size))
184
+ chunk (x, n:: Int ) = chunk (x; size = cld (length (x), n))
162
185
163
- function chunk (x:: AbstractArray , n :: Int ; dims:: Int = ndims (x))
164
- idxs = _partition_idxs (x, n , dims)
186
+ function chunk (x:: AbstractArray ; size :: Int , dims:: Int = ndims (x))
187
+ idxs = _partition_idxs (x, size , dims)
165
188
[selectdim (x, dims, i) for i in idxs]
166
189
end
190
+ chunk (x:: AbstractArray , n:: Int ; dims:: Int = ndims (x)) = chunk (x; size = cld (size (x, dims), n), dims)
167
191
168
- function _partition_idxs (x, n, dims)
169
- bs = ceil (Int, size (x, dims) / n)
170
- Iterators. partition (axes (x, dims), bs)
171
- end
172
-
173
- function rrule (:: typeof (chunk), x:: AbstractArray , n:: Int ; dims:: Int = ndims (x))
192
+ function rrule (:: typeof (chunk), x:: AbstractArray ; size:: Int , dims:: Int = ndims (x))
174
193
# this is the implementation of chunk
175
- idxs = _partition_idxs (x, n , dims)
194
+ idxs = _partition_idxs (x, size , dims)
176
195
y = [selectdim (x, dims, i) for i in idxs]
177
196
valdims = Val (dims)
178
- chunk_pullback (dy) = (NoTangent (), ∇chunk (unthunk (dy), x, idxs, valdims), NoTangent () )
179
-
197
+ chunk_pullback (dy) = (NoTangent (), ∇chunk (unthunk (dy), x, idxs, valdims))
198
+
180
199
return y, chunk_pullback
181
200
end
182
201
202
+ _partition_idxs (x, size, dims) = Iterators. partition (axes (x, dims), size)
203
+
183
204
# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
184
205
function ∇chunk (dys, x:: AbstractArray , idxs, vd:: Val{dim} ) where {dim}
185
206
i1 = findfirst (dy -> ! (dy isa AbstractZero), dys)
0 commit comments