|
| 1 | +#States |
| 2 | +element(state,seqindex,batchindex) = selectdim(selectdim(state, ndims(state), batchindex), ndims(state)-1, seqindex) |
| 3 | +element(S::MaskedState, seqindex, batchindex) = element(S.S, seqindex, batchindex) |
| 4 | +element(S::ContinuousState, seqindex, batchindex) = ContinuousState(element(S.state, seqindex, batchindex)) |
| 5 | +element(S::ManifoldState, seqindex, batchindex) = ManifoldState(S.M, element(S.state, seqindex, batchindex)) |
| 6 | +element(S::DiscreteState, seqindex, batchindex) = DiscreteState(S.K, element(S.state, seqindex, batchindex)) |
| 7 | +element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex, batchindex) = element.(S, seqindex, batchindex) |
| 8 | + |
| 9 | +#When there isn't a batch dim: |
| 10 | +element(state,seqindex) = selectdim(state, ndims(state), seqindex) |
| 11 | +element(S::MaskedState, seqindex) = element(S.S, seqindex) |
| 12 | +element(S::ContinuousState, seqindex) = ContinuousState(element(S.state, seqindex)) |
| 13 | +element(S::ManifoldState, seqindex) = ManifoldState(S.M, element(S.state, seqindex)) |
| 14 | +element(S::DiscreteState, seqindex) = DiscreteState(S.K, element(S.state, seqindex)) |
| 15 | +element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex) = element.(S, seqindex) |
| 16 | + |
| 17 | +#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think. |
| 18 | +zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))..., expandsize...) .= 0) |
| 19 | +zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))..., expandsize...) .= element.K) |
| 20 | +zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K)) |
| 21 | +function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}} |
| 22 | + newtensor = similar(tensor(element), size(tensor(element))..., expandsize...) .= 0 |
| 23 | + for i in 1:manifold_dimension(element.M) |
| 24 | + selectdim(selectdim(newtensor, 1,i),1,i) .= 1 |
| 25 | + end |
| 26 | + return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize)))) |
| 27 | +end |
| 28 | + |
| 29 | +#In general, these will be different lengths, so we use an array of arrays as input. |
| 30 | +#Doesn't work for onehot states yet. |
| 31 | +function regroup(elarray::AbstractArray{<:AbstractArray}) |
| 32 | + example_tuple = elarray[1][1] |
| 33 | + maxlen = maximum(length.(elarray)) |
| 34 | + b = length(elarray) |
| 35 | + @show maxlen, b |
| 36 | + newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)] |
| 37 | + for ne in newstates |
| 38 | + @show size(tensor(ne)) |
| 39 | + end |
| 40 | + for i in 1:b |
| 41 | + for j in 1:length(elarray[i]) |
| 42 | + for k in 1:length(example_tuple) |
| 43 | + element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k]) |
| 44 | + end |
| 45 | + end |
| 46 | + end |
| 47 | + return Tuple(newstates) |
| 48 | +end |
0 commit comments