1
+ # This will be in Flowfusion:
1
2
element (state, seqindex) = selectdim (state, ndims (state), seqindex: seqindex)
2
3
element (state, seqindex, batchindex) = element (selectdim (state, ndims (state), batchindex), seqindex)
3
-
4
4
element (S:: MaskedState , inds... ) = element (S. S, inds... )
5
5
element (S:: ContinuousState , inds... ) = ContinuousState (element (S. state, inds... ))
6
6
element (S:: ManifoldState , inds... ) = ManifoldState (S. M, element (S. state, inds... ))
7
7
element (S:: DiscreteState , inds... ) = DiscreteState (S. K, element (S. state, inds... ))
8
-
9
8
element (S:: Tuple{Vararg{Flowfusion.UState}} , inds... ) = element .(S, inds... )
10
9
10
+
11
11
# Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
12
12
zerostate (element:: T , expandsize... ) where T <: ContinuousState = T (similar (tensor (element), size (tensor (element))[1 : end - 1 ]. .. , expandsize... ) .= 0 )
13
13
zerostate (element:: DiscreteState{<:AbstractArray{<:Signed}} , expandsize... ) = DiscreteState (element. K,similar (tensor (element), size (tensor (element))[1 : end - 1 ]. .. , expandsize... ) .= element. K)
@@ -31,11 +31,7 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
31
31
example_tuple = elarray[1 ][1 ]
32
32
maxlen = maximum (length .(elarray))
33
33
b = length (elarray)
34
- @show maxlen, b
35
34
newstates = [zerostate (example_tuple[i],maxlen,b) for i in 1 : length (example_tuple)]
36
- for ne in newstates
37
- @show size (tensor (ne))
38
- end
39
35
for i in 1 : b
40
36
for j in 1 : length (elarray[i])
41
37
for k in 1 : length (example_tuple)
@@ -44,4 +40,4 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
44
40
end
45
41
end
46
42
return Tuple (newstates)
47
- end
43
+ end
0 commit comments