|
1 |
| -#States |
2 |
| -element(state,seqindex,batchindex) = selectdim(selectdim(state, ndims(state), batchindex), ndims(state)-1, seqindex) |
3 |
| -element(S::MaskedState, seqindex, batchindex) = element(S.S, seqindex, batchindex) |
4 |
| -element(S::ContinuousState, seqindex, batchindex) = ContinuousState(element(S.state, seqindex, batchindex)) |
5 |
| -element(S::ManifoldState, seqindex, batchindex) = ManifoldState(S.M, element(S.state, seqindex, batchindex)) |
6 |
| -element(S::DiscreteState, seqindex, batchindex) = DiscreteState(S.K, element(S.state, seqindex, batchindex)) |
7 |
| -element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex, batchindex) = element.(S, seqindex, batchindex) |
| 1 | +element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex) |
| 2 | +element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex) |
8 | 3 |
|
9 |
| -#When there isn't a batch dim: |
10 |
| -element(state,seqindex) = selectdim(state, ndims(state), seqindex) |
11 |
| -element(S::MaskedState, seqindex) = element(S.S, seqindex) |
12 |
| -element(S::ContinuousState, seqindex) = ContinuousState(element(S.state, seqindex)) |
13 |
| -element(S::ManifoldState, seqindex) = ManifoldState(S.M, element(S.state, seqindex)) |
14 |
| -element(S::DiscreteState, seqindex) = DiscreteState(S.K, element(S.state, seqindex)) |
15 |
| -element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex) = element.(S, seqindex) |
| 4 | +element(S::MaskedState, inds...) = element(S.S, inds...) |
| 5 | +element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...)) |
| 6 | +element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...)) |
| 7 | +element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...)) |
| 8 | + |
| 9 | +element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...) |
16 | 10 |
|
17 | 11 | #Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
|
18 | 12 | zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))..., expandsize...) .= 0)
|
|
0 commit comments