Skip to content

Commit aa25783

Browse files
committed
fix: fix issue
1 parent 7e70264 commit aa25783

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/batchs.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ struct ConcatenatedBatch{T<:AbstractArray}
1212
field::T
1313
lengths::Vector{Int}
1414
function ConcatenatedBatch(field::T, lengths::Vector{Int}) where {T<:AbstractArray}
15-
@assert first(lengths) == 0 "got $lengths"
16-
@assert issorted(lengths) "got $lengths"
15+
@assert first(lengths) == 0 "got $lengths"
16+
@assert issorted(lengths) "got $lengths"
1717
@assert last(lengths) == size(field)[end] "got $lengths, size is $( size(field))"
1818
new{T}(field, lengths)
1919
end
@@ -22,7 +22,9 @@ function ConcatenatedBatch((; field)::Batch)
2222
ConcatenatedBatch(cat(field; dims=ndims(field)), vcat([0], field .|> size .|> last |> cumsum))
2323
end
2424
function stack_ConcatenatedBatch(x::AbstractVector{<:ConcatenatedBatch})
25-
field = reduce((x,y) -> cat(x,y; dims=ndims(first(x).field)),x)
25+
field = reduce(x) do a, b
26+
cat(a, b; dims=ndims(a))
27+
end
2628
offsets = vcat([0], getfield.(x, :lengths) .|> last)::Vector{Int} |> cumsum |> DropLast(1)
2729
lengths = vcat([0], reduce(vcat,
2830
zip(getfield.(x, :lengths) |> Map(Drop(1)), offsets) |> Map(((lengths, offset),) -> lengths .+ offset)

0 commit comments

Comments
 (0)