Skip to content

Commit cca34eb

Browse files
authored
Adding some States->Elements and back
WIP
1 parent b4525f7 commit cca34eb

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

src/batching.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#States
2+
element(state,seqindex,batchindex) = selectdim(selectdim(state, ndims(state), batchindex), ndims(state)-1, seqindex)
3+
element(S::MaskedState, seqindex, batchindex) = element(S.S, seqindex, batchindex)
4+
element(S::ContinuousState, seqindex, batchindex) = ContinuousState(element(S.state, seqindex, batchindex))
5+
element(S::ManifoldState, seqindex, batchindex) = ManifoldState(S.M, element(S.state, seqindex, batchindex))
6+
element(S::DiscreteState, seqindex, batchindex) = DiscreteState(S.K, element(S.state, seqindex, batchindex))
7+
element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex, batchindex) = element.(S, seqindex, batchindex)
8+
9+
#When there isn't a batch dim:
10+
element(state,seqindex) = selectdim(state, ndims(state), seqindex)
11+
element(S::MaskedState, seqindex) = element(S.S, seqindex)
12+
element(S::ContinuousState, seqindex) = ContinuousState(element(S.state, seqindex))
13+
element(S::ManifoldState, seqindex) = ManifoldState(S.M, element(S.state, seqindex))
14+
element(S::DiscreteState, seqindex) = DiscreteState(S.K, element(S.state, seqindex))
15+
element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex) = element.(S, seqindex)
16+
17+
#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
18+
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))..., expandsize...) .= 0)
19+
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))..., expandsize...) .= element.K)
20+
zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K))
21+
function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}}
22+
newtensor = similar(tensor(element), size(tensor(element))..., expandsize...) .= 0
23+
for i in 1:manifold_dimension(element.M)
24+
selectdim(selectdim(newtensor, 1,i),1,i) .= 1
25+
end
26+
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))
27+
end
28+
29+
#In general, these will be different lengths, so we use an array of arrays as input.
30+
#Doesn't work for onehot states yet.
31+
function regroup(elarray::AbstractArray{<:AbstractArray})
32+
example_tuple = elarray[1][1]
33+
maxlen = maximum(length.(elarray))
34+
b = length(elarray)
35+
@show maxlen, b
36+
newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)]
37+
for ne in newstates
38+
@show size(tensor(ne))
39+
end
40+
for i in 1:b
41+
for j in 1:length(elarray[i])
42+
for k in 1:length(example_tuple)
43+
element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k])
44+
end
45+
end
46+
end
47+
return Tuple(newstates)
48+
end

0 commit comments

Comments
 (0)