Skip to content

Commit 40a125c

Browse files
authored
Fixing zerostate
1 parent 5d1f0ab commit 40a125c

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/batching.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@ element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...
99
element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...)
1010

1111
#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
12-
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))..., expandsize...) .= 0)
13-
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))..., expandsize...) .= element.K)
12+
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0)
13+
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K)
1414
zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K))
1515
function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}}
16-
newtensor = similar(tensor(element), size(tensor(element))..., expandsize...) .= 0
16+
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
1717
for i in 1:manifold_dimension(element.M)
1818
selectdim(selectdim(newtensor, 1,i),1,i) .= 1
1919
end
2020
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))
2121
end
22+
#Pls test this general version with other manifolds? Not sure this will handle the various underlying representations
23+
function zerostate(element::ManifoldState, expandsize...)
24+
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
25+
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))
26+
end
2227

2328
#In general, these will be different lengths, so we use an array of arrays as input.
2429
#Doesn't work for onehot states yet.
@@ -39,4 +44,4 @@ function regroup(elarray::AbstractArray{<:AbstractArray})
3944
end
4045
end
4146
return Tuple(newstates)
42-
end
47+
end

0 commit comments

Comments
 (0)