Skip to content

Commit 9b000c2

Browse files
authored
Update batching.jl
1 parent 40a125c commit 9b000c2

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

src/batching.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
#This will be in Flowfusion:
12
element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex)
23
element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex)
3-
44
element(S::MaskedState, inds...) = element(S.S, inds...)
55
element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...))
66
element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...))
77
element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...))
8-
98
element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...)
109

10+
1111
#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
1212
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0)
1313
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K)
@@ -31,11 +31,7 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
3131
example_tuple = elarray[1][1]
3232
maxlen = maximum(length.(elarray))
3333
b = length(elarray)
34-
@show maxlen, b
3534
newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)]
36-
for ne in newstates
37-
@show size(tensor(ne))
38-
end
3935
for i in 1:b
4036
for j in 1:length(elarray[i])
4137
for k in 1:length(example_tuple)
@@ -44,4 +40,4 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
4440
end
4541
end
4642
return Tuple(newstates)
47-
end
43+
end

0 commit comments

Comments
 (0)