Skip to content

Commit 531174a

Browse files
use stack in batch
1 parent a9e60cb commit 531174a

File tree

1 file changed

+1
-11
lines changed

1 file changed

+1
-11
lines changed

src/utils.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,7 @@ end
286286

287287
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
288288

289-
function batch(xs::AbstractArray{<:AbstractArray})
290-
# Don't use stack(xs, dims=N+1), it is much slower.
291-
# Here we do reduce(vcat, xs) along with some reshapes.
292-
szxs = size(xs)
293-
@assert length(xs) > 0 "Minimum batch size is 1."
294-
szx = size(xs[1])
295-
@assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
296-
vxs = vec(vec.(xs))
297-
y = reduce(vcat, vxs)
298-
return reshape(y, szx..., szxs...)
299-
end
289+
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)
300290

301291
function batch(xs::Vector{<:Tuple})
302292
@assert length(xs) > 0 "Input should be non-empty"

0 commit comments

Comments
 (0)