Skip to content

Commit 1a0a741

Browse files
committed
First code
1 parent 3c64b43 commit 1a0a741

File tree

8 files changed

+30574
-1
lines changed

8 files changed

+30574
-1
lines changed

Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,21 @@ uuid = "5b4e93c8-7b6e-4682-b400-fc3b238f52b1"
33
authors = ["murrellb <[email protected]> and contributors"]
44
version = "1.0.0-DEV"
55

6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
9+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
10+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
11+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
12+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
13+
614
[compat]
15+
Adapt = "4.1.1"
16+
ForwardBackward = "1.0.0"
17+
Manifolds = "0.10.12"
18+
NNlib = "0.9.27"
19+
OneHotArrays = "0.2.6"
20+
StatsBase = "0.34.4"
721
julia = "1.9"
822

923
[extras]

examples/Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[deps]
2+
Flowfusion = "5b4e93c8-7b6e-4682-b400-fc3b238f52b1"
3+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4+
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
5+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
6+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
7+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
8+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
9+
RandomFeatureMaps = "780baa95-dd42-481b-93db-80fe3d88832c"

examples/continuous.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using Pkg
2+
Pkg.activate(".")
3+
using Revise
4+
Pkg.develop(path="../../ForwardBackward/")
5+
Pkg.develop(path="../")
6+
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
7+
8+
#Set up a Flux model: X̂1 = model(t,Xt)
9+
struct FModel{A}
10+
layers::A
11+
end
12+
Flux.@layer FModel
13+
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
14+
embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
15+
embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
16+
ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
17+
decode = Dense(embeddim => spacedim)
18+
layers = (; embed_time, embed_state, ffs, decode)
19+
FModel(layers)
20+
end
21+
22+
function (f::FModel)(t, Xt)
23+
l = f.layers
24+
tXt = tensor(Xt)
25+
tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
26+
x = l.embed_time(tv) .+ l.embed_state(tXt)
27+
for ff in l.ffs
28+
x = x .+ ff(x)
29+
end
30+
tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
31+
end
32+
33+
model = FModel(embeddim = 256, layers = 3, spacedim = 2)
34+
35+
#Distributions for training:
36+
T = Float32
37+
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
38+
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
39+
n_samples = 200
40+
41+
#The process:
42+
P = BrownianMotion(0.1f0)
43+
#P = Deterministic()
44+
45+
#Optimizer:
46+
eta = 0.01
47+
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)
48+
49+
iters = 5000
50+
for i in 1:iters
51+
#Set up a batch of training pairs, and t:
52+
X1 = ContinuousState(sampleX1(n_samples))
53+
X0 = ContinuousState(sampleX0(n_samples))
54+
t = rand(T, n_samples)
55+
#Construct the bridge:
56+
Xt = bridge(P, X0, X1, t)
57+
#Gradient:
58+
l,g = Flux.withgradient(model) do m
59+
floss(P, m(t,Xt), X1, scalefloss(P, t, 2))
60+
end
61+
#Update:
62+
Flux.update!(opt_state, model, g[1])
63+
#Logging, and lr cooldown:
64+
if i % 10 == 0
65+
if i > iters - 2000
66+
eta *= 0.975
67+
Optimisers.adjust!(opt_state, eta)
68+
end
69+
println("i: $i; Loss: $l; eta: $eta")
70+
end
71+
end
72+
73+
#Generate samples by stepping from X0
74+
n_inference_samples = 5000
75+
X0 = ContinuousState(sampleX0(n_inference_samples))
76+
paths = Tracker()
77+
samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
78+
79+
#Plotting:
80+
pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
81+
tvec = stack_tracker(paths, :t)
82+
xttraj = stack_tracker(paths, :xt)
83+
for i in 1:50:1000
84+
plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = i==1 ? "Trajectory" : :none, alpha = 0.4)
85+
end
86+
X1true = sampleX1(n_inference_samples)
87+
scatter!(X1true[1,:],X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
88+
scatter!(samp.state[1,:],samp.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
89+
display(pl)
90+
savefig("continuous_cat_$P.svg")

examples/continuous_cat_BrownianMotion{Float32}(0.0f0, 0.1f0).svg

Lines changed: 15072 additions & 0 deletions
Loading

examples/continuous_cat_Deterministic().svg

Lines changed: 15070 additions & 0 deletions
Loading

src/Flowfusion.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
11
module Flowfusion
22

3-
# Write your package code here.
3+
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
4+
5+
include("bridge.jl")
6+
include("loss.jl")
7+
8+
export
9+
MaskedState,
10+
bridge,
11+
scalefloss,
12+
gen,
13+
Tracker,
14+
stack_tracker,
15+
onehot,
16+
FProcess,
17+
tangent_coordinates,
18+
apply_tangent_coordinates,
19+
floss,
20+
tcloss
21+
22+
23+
#Useful for demos etc:
24+
#Define a cat - from https://www.geogebra.org/m/pH8wD3rW
25+
cat_shape(t) = [-(721*sin(t))/4+196/3*sin(2*t)-86/3*sin(3*t)-131/2*sin(4*t)+477/14*sin(5*t)+27*sin(6*t)-29/2*sin(7*t)+68/5*sin(8*t)+1/10*sin(9*t)+23/4*sin(10*t)-19/2*sin(12*t)-85/21*sin(13*t)+2/3*sin(14*t)+27/5*sin(15*t)+7/4*sin(16*t)+17/9*sin(17*t)-4*sin(18*t)-1/2*sin(19*t)+1/6*sin(20*t)+6/7*sin(21*t)-1/8*sin(22*t)+1/3*sin(23*t)+3/2*sin(24*t)+13/5*sin(25*t)+sin(26*t)-2*sin(27*t)+3/5*sin(28*t)-1/5*sin(29*t)+1/5*sin(30*t)+(2337*cos(t))/8-43/5*cos(2*t)+322/5*cos(3*t)-117/5*cos(4*t)-26/5*cos(5*t)-23/3*cos(6*t)+143/4*cos(7*t)-11/4*cos(8*t)-31/3*cos(9*t)-13/4*cos(10*t)-9/2*cos(11*t)+41/20*cos(12*t)+8*cos(13*t)+2/3*cos(14*t)+6*cos(15*t)+17/4*cos(16*t)-3/2*cos(17*t)-29/10*cos(18*t)+11/6*cos(19*t)+12/5*cos(20*t)+3/2*cos(21*t)+11/12*cos(22*t)-4/5*cos(23*t)+cos(24*t)+17/8*cos(25*t)-7/2*cos(26*t)-5/6*cos(27*t)-11/10*cos(28*t)+1/2*cos(29*t)-1/5*cos(30*t),
26+
-(637*sin(t))/2-188/5*sin(2*t)-11/7*sin(3*t)-12/5*sin(4*t)+11/3*sin(5*t)-37/4*sin(6*t)+8/3*sin(7*t)+65/6*sin(8*t)-32/5*sin(9*t)-41/4*sin(10*t)-38/3*sin(11*t)-47/8*sin(12*t)+5/4*sin(13*t)-41/7*sin(14*t)-7/3*sin(15*t)-13/7*sin(16*t)+17/4*sin(17*t)-9/4*sin(18*t)+8/9*sin(19*t)+3/5*sin(20*t)-2/5*sin(21*t)+4/3*sin(22*t)+1/3*sin(23*t)+3/5*sin(24*t)-3/5*sin(25*t)+6/5*sin(26*t)-1/5*sin(27*t)+10/9*sin(28*t)+1/3*sin(29*t)-3/4*sin(30*t)-(125*cos(t))/2-521/9*cos(2*t)-359/3*cos(3*t)+47/3*cos(4*t)-33/2*cos(5*t)-5/4*cos(6*t)+31/8*cos(7*t)+9/10*cos(8*t)-119/4*cos(9*t)-17/2*cos(10*t)+22/3*cos(11*t)+15/4*cos(12*t)-5/2*cos(13*t)+19/6*cos(14*t)+7/4*cos(15*t)+31/4*cos(16*t)-cos(17*t)+11/10*cos(18*t)-2/3*cos(19*t)+13/3*cos(20*t)-5/4*cos(21*t)+2/3*cos(22*t)+1/4*cos(23*t)+5/6*cos(24*t)+3/4*cos(26*t)-1/2*cos(27*t)-1/10*cos(28*t)-1/3*cos(29*t)-1/19*cos(30*t)]
27+
28+
random_literal_cat(dims...; sigma = 0.05f0) = typeof(sigma).(stack([cat_shape(rand()*2pi)/200 for _ in zeros(dims...)]) .+ randn(2, dims...) * sigma)
29+
430

531
end

src/bridge.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#=#####################
2+
Assumptions:
3+
- t ∈ [0,1]. Any behavior can be controlled by manipulating the process parameters.
4+
- FProcess.F is a monotonic function, with F(0) = 0 and F(1) = 1.
5+
- Default sampling steps are FProcess.F(t) with even t intervals [NOTE TO SELF: Intervals should be F(t2)-F(t1)]
6+
=#####################
7+
8+
struct FProcess{A,B}
9+
P::A #Process
10+
F::B #Time transform
11+
end
12+
13+
UProcess = Union{Process,FProcess}
14+
process(P::FProcess) = P.P
15+
process(P::Process) = P
16+
17+
tscale(P::Process, t) = t
18+
tscale(P::FProcess, t) = P.F.(t)
19+
20+
struct MaskedState{A,B,C}
21+
S::A #State
22+
cmask::B #Conditioning mask. 1 = Xt=X1
23+
lmask::C #Loss mask. 1 = included in loss
24+
end
25+
26+
Adapt.adapt_structure(to, S::ForwardBackward.DiscreteState) = ForwardBackward.DiscreteState(S.K, Adapt.adapt(to, S.state))
27+
Adapt.adapt_structure(to, S::ForwardBackward.ContinuousState) = ForwardBackward.ContinuousState(Adapt.adapt(to, S.state))
28+
Adapt.adapt_structure(to, S::ForwardBackward.CategoricalLikelihood) = ForwardBackward.CategoricalLikelihood(Adapt.adapt(to, S.dist), Adapt.adapt(to, S.log_norm_const))
29+
Adapt.adapt_structure(to, MS::MaskedState{<:State}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
30+
Adapt.adapt_structure(to, MS::MaskedState{<:CategoricalLikelihood}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
31+
Adapt.adapt_structure(to, S::ForwardBackward.ManifoldState) = ForwardBackward.ManifoldState(S.M, Adapt.adapt(to, S.state))
32+
33+
UState = Union{State,MaskedState}
34+
35+
ForwardBackward.tensor(X::MaskedState) = tensor(X.S)
36+
37+
import Base.copy
38+
copy(X::MaskedState) = MaskedState(copy(X.S), copy(X.cmask), copy(X.lmask))
39+
40+
"""
41+
endslices(a,m)
42+
43+
Returns a view of `a` where slices specified by `m` are selected. `m` can be multidimensional, but the dimensions of m must match the last dimensions of `a`.
44+
For example, if `m` is a boolean array, then `size(a)[ndims(a)-ndims(m):end] == size(m)`.
45+
"""
46+
endslices(a,m) = @view a[ntuple(Returns(:),ndims(a)-ndims(m))...,m]
47+
48+
"""
49+
cmask!(Xt_state, X1_state, cmask)
50+
cmask!(Xt, X1)
51+
52+
Applies, in place, a conditioning mask, forcing elements (or slices) of `Xt` to be equal to `X1`, where `cmask` is 1.
53+
"""
54+
function cmask!(Xt_state, X1_state, cmask)
55+
endslices(Xt_state,cmask) .= endslices(X1_state,cmask)
56+
return Xt_state
57+
end
58+
59+
cmask!(Xt_state, X1_state, cmask::Nothing) = Xt_state
60+
cmask!(Xt, X1::State) = Xt
61+
cmask!(Xt, X1::StateLikelihood) = Xt
62+
cmask!(Xt, X1::MaskedState) = cmask!(Xt.S.state, X1.S.state, X1.cmask)
63+
cmask!(Xt, X1::MaskedState{<:CategoricalLikelihood}) = error("Cannot condition on a CategoricalLikelihood")
64+
cmask!(x̂₁::Tuple, x₀::Tuple) = map(cmask!, x̂₁, x₀)
65+
66+
"""
67+
bridge(P, X0, X1, t)
68+
bridge(P, X0, X1, t0, t)
69+
70+
Samples `Xt` at `t` conditioned on `X0` and `X1` under the process `P`. Start time is `t0` (0 if not specified). End time is 1.
71+
If `X1` is a `MaskedState`, then `Xt` will equal `X1` where the conditioning mask `X1.cmask` is 1.
72+
`P`, `X0`, `X1` can also be tuples where the Nth element of `P` will be used for the Nth elements of `X0` and `X1`.
73+
The same `t` and (optionally) `t0` will be used for all elements. If you need a different `t` for each Proces/State, broadcast with `bridge.(P, X0, X1, t0, t)`.
74+
"""
75+
76+
function bridge(P::UProcess, X0::UState, X1, t0, t)
77+
T = eltype(t)
78+
tF = T.(tscale(P,t) .- tscale(P,t0))
79+
tB = T.(tscale(P,1) .- tscale(P,t))
80+
endpoint_conditioned_sample(cmask!(X0,X1), X1, process(P), tF, tB)
81+
end
82+
bridge(P, X0, X1, t) = bridge(P, X0, X1, eltype(t)(0.0), t)
83+
bridge(P::Tuple{Vararg{UProcess}}, X0::Tuple{Vararg{UState}}, X1::Tuple, t0, t) = bridge.(P, X0, X1, (t0,), (t, ))
84+
85+
86+
87+
#copytensor! and predictresolve are used handle the state translation that happens in gen(...).
88+
#We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc).
89+
#This means we need to automagically create a State (typical for the continuous case) or Likelihood (typical for the discrete case) from the tensor.
90+
#But the user may return a State in the Discrete case (for massive state spaces with sub-linear sampling), and a Likelihood in the Continuous case (for variance matching models)
91+
#This also needs to handle MaskedStates (needs testing).
92+
#We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
93+
#Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass through.
94+
#When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
95+
96+
function copytensor!(dest, src)
97+
tensor(dest) .= tensor(src)
98+
return dest
99+
end
100+
#copytensor!(dest::Tuple, src::Tuple) = map(copytensor!, dest, src)
101+
102+
#Tuple broadcast:
103+
resolveprediction(dest::Tuple, src::Tuple) = map(resolveprediction, dest, src)
104+
#Default if X̂₁ is a plain tensor:
105+
resolveprediction(X̂₁, X₀::DiscreteState) = copytensor!(stochastic(X₀), X̂₁) #Returns a Likelihood
106+
resolveprediction(X̂₁, X₀::State) = copytensor!(copy(X₀), X̂₁) #Returns a State - Handles Continuous and Manifold cases
107+
#Passthrough if the user returns a State or Likelihood
108+
resolveprediction(X̂₁::State, X₀) = X̂₁
109+
resolveprediction(X̂₁::State, X₀::State) = X̂₁
110+
resolveprediction(X̂₁::StateLikelihood, X₀) = X̂₁
111+
#####Add MaskedState case(s)######
112+
113+
##################################
114+
115+
116+
117+
"""
118+
gen(P, X0, X̂₁predictor, steps; tracker=Returns(nothing), midpoint = false)
119+
120+
Constructs a sequence of (stochastic) bridges between `X0` and the predicted `X̂₁` under the process `P`.
121+
`P`, `X0`, can also be tuples where the Nth element of `P` will be used for the Nth elements of `X0` and `X̂₁predictor`.
122+
X̂₁predictor is a function that takes `t` (scalar) and `Xₜ` (optionally a tuple) and returns `X̂₁` (a `UState`, a flat tensor with the right shape, or a tuple of either).
123+
If `X0` is a `MaskedState` (or has a ), then anything `X̂₁` will be conditioned on `X0` where the conditioning mask `X0.cmask` is 1.
124+
"""
125+
function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, X̂₁predictor, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false)
126+
Xₜ = copy.(X₀)
127+
for (s₁, s₂) in zip(steps, steps[begin+1:end])
128+
t = midpoint ? (s₁ + s₂) / 2 : t = s₁
129+
X̂₁ = resolveprediction(X̂₁predictor(t, Xₜ), X₀)
130+
cmask!(X̂₁, X₀)
131+
Xₜ = bridge(P, Xₜ, X̂₁, s₁, s₂)
132+
tracker(t, Xₜ, X̂₁)
133+
end
134+
return Xₜ
135+
end
136+
137+
gen(P, X₀, X̂₁predictor, args...; kwargs...) = gen((P,), (X₀,), (t, Xₜ) -> (X̂₁predictor(t[1], Xₜ[1]),), args...; kwargs...)[1]
138+
139+
struct Tracker <: Function
140+
t::Vector
141+
xt::Vector
142+
x̂1::Vector
143+
end
144+
145+
Tracker() = Tracker([], [], [])
146+
147+
function (tracker::Tracker)(t, xt, x̂1)
148+
push!(tracker.t, t)
149+
push!(tracker.xt, xt)
150+
push!(tracker.x̂1, x̂1)
151+
return nothing
152+
end
153+
154+
function stack_tracker(tracker, field; tuple_index = 1)
155+
return stack([tensor(data[tuple_index]) for data in getproperty(tracker, field)])
156+
end

0 commit comments

Comments
 (0)