We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a9e60cb commit 531174aCopy full SHA for 531174a
src/utils.jl
@@ -286,17 +286,7 @@ end
286
287
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
288
289
-function batch(xs::AbstractArray{<:AbstractArray})
290
- # Don't use stack(xs, dims=N+1), it is much slower.
291
- # Here we do reduce(vcat, xs) along with some reshapes.
292
- szxs = size(xs)
293
- @assert length(xs) > 0 "Minimum batch size is 1."
294
- szx = size(xs[1])
295
- @assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
296
- vxs = vec(vec.(xs))
297
- y = reduce(vcat, vxs)
298
- return reshape(y, szx..., szxs...)
299
-end
+batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)
300
301
function batch(xs::Vector{<:Tuple})
302
@assert length(xs) > 0 "Input should be non-empty"
0 commit comments