@@ -9,16 +9,21 @@ element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...
9
9
element (S:: Tuple{Vararg{Flowfusion.UState}} , inds... ) = element .(S, inds... )
10
10
11
11
# Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
12
- zerostate (element:: T , expandsize... ) where T <: ContinuousState = T (similar (tensor (element), size (tensor (element))... , expandsize... ) .= 0 )
13
- zerostate (element:: DiscreteState{<:AbstractArray{<:Signed}} , expandsize... ) = DiscreteState (element. K,similar (tensor (element), size (tensor (element))... , expandsize... ) .= element. K)
12
+ zerostate (element:: T , expandsize... ) where T <: ContinuousState = T (similar (tensor (element), size (tensor (element))[ 1 : end - 1 ] . .. , expandsize... ) .= 0 )
13
+ zerostate (element:: DiscreteState{<:AbstractArray{<:Signed}} , expandsize... ) = DiscreteState (element. K,similar (tensor (element), size (tensor (element))[ 1 : end - 1 ] . .. , expandsize... ) .= element. K)
14
14
zerostate (element:: DiscreteState , expandsize... ) = Flowfusion. onehot (DiscreteState (element. K,zeros (Int,expandsize... ) .= element. K))
15
15
function zerostate (element:: T , expandsize... ) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}}
16
- newtensor = similar (tensor (element), size (tensor (element))... , expandsize... ) .= 0
16
+ newtensor = similar (tensor (element), size (tensor (element))[ 1 : end - 1 ] . .. , expandsize... ) .= 0
17
17
for i in 1 : manifold_dimension (element. M)
18
18
selectdim (selectdim (newtensor, 1 ,i),1 ,i) .= 1
19
19
end
20
20
return ManifoldState (element. M, eachslice (newtensor, dims= ntuple (i -> 2 + i, length (expandsize))))
21
21
end
22
+ # Pls test this general version with other manifolds? Not sure this will handle the various underlying representations
23
+ function zerostate (element:: ManifoldState , expandsize... )
24
+ newtensor = similar (tensor (element), size (tensor (element))[1 : end - 1 ]. .. , expandsize... ) .= 0
25
+ return ManifoldState (element. M, eachslice (newtensor, dims= ntuple (i -> 2 + i, length (expandsize))))
26
+ end
22
27
23
28
# In general, these will be different lengths, so we use an array of arrays as input.
24
29
# Doesn't work for onehot states yet.
@@ -39,4 +44,4 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
39
44
end
40
45
end
41
46
return Tuple (newstates)
42
- end
47
+ end
0 commit comments