Skip to content

Commit 129e311

Browse files
committed
fix: fixed issue when concatenating concatenated batchs
1 parent 798a5fe commit 129e311

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/batchs.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ function ConcatenatedBatch((; field)::Batch)
2323
end
2424
function stack_ConcatenatedBatch(x::AbstractVector{<:ConcatenatedBatch})
2525
field = cat(getfield.(x, :field), dims=ndims(first(x).field))
26-
offsets = vcat([0], getfield.(x, :lengths) .|> last)::Vector{Int} |> cumsum
27-
lengths = zip(getfield.(x, :lengths), offsets) |> Map((lengths, offset) -> lengths .+ offsets) |> vcat |> collect
26+
offsets = vcat([0], getfield.(x, :lengths) .|> last)::Vector{Int} |> cumsum |> DropLast(1)
27+
lengths = vcat([0], reduce(vcat,
28+
zip(getfield.(x, :lengths) |> Map(Drop(1)), offsets) |> Map(((lengths, offset),) -> lengths .+ offset)
29+
))
30+
2831
ConcatenatedBatch(field, lengths)
2932
end
3033
get_slice(lengths::Vector{Int}, i::Integer) = (lengths[i]+1):lengths[i+1]

0 commit comments

Comments
 (0)