Skip to content

Commit 5d1f0ab

Browse files
committed
commit to branch (flow stuff)
1 parent cca34eb commit 5d1f0ab

File tree

6 files changed

+18
-24
lines changed

6 files changed

+18
-24
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
/docs/Manifest.toml
55
/docs/build/
6+
Manifest*.toml

Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
910
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
1011
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -13,15 +14,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1314

1415
[compat]
1516
Adapt = "4.1.1"
17+
FillArrays = "1.13.0"
1618
ForwardBackward = "0.1.0"
1719
Manifolds = "0.10.12"
1820
NNlib = "0.9.27"
1921
OneHotArrays = "0.2.6"
2022
StatsBase = "0.34.4"
2123
julia = "1.9"
22-
23-
[extras]
24-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
25-
26-
[targets]
27-
test = ["Test"]

src/Flowfusion.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ Later:
1717
=#
1818

1919

20-
21-
22-
2320
module Flowfusion
2421

2522
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
@@ -30,6 +27,8 @@ include("bridge.jl")
3027
include("loss.jl")
3128
include("processes.jl")
3229

30+
include("batching.jl")
31+
3332
export
3433
#Processes not in ForwardBackward.jl
3534
InterpolatingDiscreteFlow,

src/batching.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
#States
2-
element(state,seqindex,batchindex) = selectdim(selectdim(state, ndims(state), batchindex), ndims(state)-1, seqindex)
3-
element(S::MaskedState, seqindex, batchindex) = element(S.S, seqindex, batchindex)
4-
element(S::ContinuousState, seqindex, batchindex) = ContinuousState(element(S.state, seqindex, batchindex))
5-
element(S::ManifoldState, seqindex, batchindex) = ManifoldState(S.M, element(S.state, seqindex, batchindex))
6-
element(S::DiscreteState, seqindex, batchindex) = DiscreteState(S.K, element(S.state, seqindex, batchindex))
7-
element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex, batchindex) = element.(S, seqindex, batchindex)
1+
element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex)
2+
element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex)
83

9-
#When there isn't a batch dim:
10-
element(state,seqindex) = selectdim(state, ndims(state), seqindex)
11-
element(S::MaskedState, seqindex) = element(S.S, seqindex)
12-
element(S::ContinuousState, seqindex) = ContinuousState(element(S.state, seqindex))
13-
element(S::ManifoldState, seqindex) = ManifoldState(S.M, element(S.state, seqindex))
14-
element(S::DiscreteState, seqindex) = DiscreteState(S.K, element(S.state, seqindex))
15-
element(S::Tuple{Vararg{Flowfusion.UState}}, seqindex) = element.(S, seqindex)
4+
element(S::MaskedState, inds...) = element(S.S, inds...)
5+
element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...))
6+
element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...))
7+
element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...))
8+
9+
element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...)
1610

1711
#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
1812
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))..., expandsize...) .= 0)

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
3+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ using ForwardBackward
4444

4545
@testset "Bridge, step" begin
4646

47-
siz = (5,6,7)
47+
siz = (5,6)
4848
XC() = ContinuousState(randn(5, siz...))
4949
XD() = DiscreteState(5, rand(1:5, siz...))
5050
MT = Torus(2)

0 commit comments

Comments
 (0)