Skip to content

Commit c7545fe

Browse files
authored
Merge pull request #13 from MurrellGroup/OUflow
Adding OUflow, and example
2 parents 058b94f + 0e21437 commit c7545fe

File tree

5 files changed

+103
-2
lines changed

5 files changed

+103
-2
lines changed

examples/OU_cat.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Optimisers, Plots
2+
3+
#Set up a Flux model: X̂1 = model(t, Xt)
#Time and state are each lifted with random Fourier features, summed, passed
#through residual dense layers, then decoded back to the state dimension.
struct FModel{A}
    layers::A
end
Flux.@layer FModel

"""
    FModel(; embeddim = 128, spacedim = 2, layers = 3)

Construct the model. `embeddim` is the hidden width, `spacedim` the dimension
of the continuous state, and `layers` the number of residual dense blocks.
"""
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    #Fix: the state feature map's input dim now follows `spacedim` (was hard-coded
    #to 2, which disagreed with `decode` whenever spacedim ≠ 2).
    embed_state = Chain(RandomFourierFeatures(spacedim => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    #Named tuple avoids shadowing the `layers` keyword argument.
    return FModel((; embed_time, embed_state, ffs, decode))
end
16+
17+
#Forward pass: predict X̂1 from the time `t` and current state `Xt`.
function (f::FModel)(t, Xt)
    nn = f.layers
    state = tensor(Xt)                      # assumes state is (spacedim, batch) — TODO confirm
    texp = expand(t, ndims(state))          # broadcastable time, shared by embed and output scaling
    tmat = zero(state[1:1, :]) .+ texp      # time replicated to a (1, batch) row
    h = nn.embed_time(tmat) .+ nn.embed_state(state)
    for blk in nn.ffs
        h = h .+ blk(h)                     # residual dense blocks
    end
    #Predict X̂1 as Xt plus a correction that is damped as t approaches 1.
    return state .+ nn.decode(h) .* (1.05f0 .- texp)
end
27+
28+
#Endpoint distributions used for training:
T = Float32
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
n_samples = 400

#Loop over three different process settings (θ, variance at t=0, variance at t=1, decay):
for (θ, v_at_0, v_at_1, dec) in [(10f0, 5f0, 0.01f0, -2f0), (2f0, 2f0, 0.1f0, -2f0), (10f0, 2f0, 0.1f0, -2f0)]
    #The process under study:
    P = OUFlow(θ, v_at_0, v_at_1, dec)

    #Fresh model and optimizer per setting:
    eta = 0.001
    model = FModel(embeddim = 256, layers = 3, spacedim = 2)
    opt_state = Flux.setup(AdamW(eta = eta), model)

    #Training loop:
    iters = 4000
    for i in 1:iters
        #Draw a fresh batch of endpoint pairs and times:
        X0 = ContinuousState(sampleX0(n_samples))
        X1 = ContinuousState(sampleX1(n_samples))
        t = 0.999f0 .* rand(T, n_samples)   # keep t strictly below 1
        #Sample the bridge state Xt given both endpoints:
        Xt = bridge(P, X0, X1, t)
        #Gradient step on the flow loss:
        l, g = Flux.withgradient(m -> floss(P, m(t, Xt), X1, scalefloss(P, t)), model)
        Flux.update!(opt_state, model, g[1])
        if i % 10 == 0
            println("i: $i; Loss: $l")
        end
    end

    #Inference: integrate the learned flow from fresh X0 samples, recording paths.
    n_inference_samples = 5000
    X0 = ContinuousState(sampleX0(n_inference_samples))
    paths = Tracker()
    samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)

    #Plot start points, a subset of trajectories, true and generated endpoints:
    pl = scatter(X0.state[1,:], X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
    tvec = stack_tracker(paths, :t)
    xttraj = stack_tracker(paths, :xt)
    for i in 1:50:1000
        plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = i==1 ? "Trajectory" : :none, alpha = 0.4)
    end
    X1true = sampleX1(n_inference_samples)
    scatter!(X1true[1,:], X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
    scatter!(samples.state[1,:], samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
    display(pl)
    savefig("OU_continuous_cat_$P.svg")

    #First-coordinate trajectories against the time axis:
    pl = plot()
    for i in 1:50:1000
        plot!(xttraj[1,i,:], color = "red", alpha = 0.4, label = :none)
    end
    pl
    savefig("OU_continuous_traj_$P.svg")
end

src/Flowfusion.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ export
3535
InterpolatingDiscreteFlow,
3636
NoisyInterpolatingDiscreteFlow,
3737
DoobMatchingFlow,
38+
OUFlow,
3839
MaskedState,
3940
Guide,
4041
tangent_guide,

src/loss.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,17 @@ msu(T) = Union{T, MaskedState{<:T}}
2929

3030
#Continuous-state processes: masked, schedule-scaled MSE between X̂₁ and X₁.
floss(P::fbu(Deterministic), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(BrownianMotion), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::OUFlow, X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁)) #Plain OUFlow (no fbu schedule wrapper): direct variance control makes a schedule unnecessary.
floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
#floss(P::fbu(OrnsteinUhlenbeck), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁)) #<- I'm not sure MSE on X1 works for this process. We need to pull X1 back to Xt and get the generator.

#Discrete processes: X̂₁ is a distribution, so X₁ must be onehot (and onehot
#encoding has to happen before any move to the GPU).
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))

#Tuples of processes: sum the per-component losses, broadcasting a shared `c`
#or pairing element-wise when `c` is itself a tuple.
floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Union{AbstractArray, Real}) = sum(floss.(P, X̂₁, X₁, (c,)))
floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Tuple) = sum(floss.(P, X̂₁, X₁, c))

#Guided targets: MSE against the guide's tangent H.
floss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ::Guide, c) = scaledmaskedmean(mse(ξhat, ξ.H), c, getlmask(ξ))


#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights.

########################################################################

src/processes.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,9 @@ function step(P::NoisyInterpolatingDiscreteFlow{<:Integer}, Xₜ::DiscreteState{
129129
clamp!(tensor(newXₜ), 0, Inf)
130130
return rand(newXₜ)
131131
end
132+
133+
#Bridge sampler for OUFlow: build an OU process whose reversion mean is the
#target X1, then draw Xt conditioned on both endpoints.
function bridge(P::OUFlow, X0, X1, t0, t)
    #Note X1 supplied as the mean; `dec` is forwarded as the decay parameter —
    #NOTE(review): exact semantics live in OrnsteinUhlenbeckExpVar, confirm there.
    proc = OrnsteinUhlenbeckExpVar(tensor(X1), P.θ, P.v_at_0, P.v_at_1, dec = P.dec)
    return endpoint_conditioned_sample(X0, X1, proc, t0, t, eltype(t)(1))
end
137+

src/types.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,14 @@ struct NoisyInterpolatingDiscreteFlow{T} <: ConvexInterpolatingDiscreteFlow
6363
dκ₁::Function # derivative of κ₁
6464
dκ₂::Function # derivative of κ₂
6565
mask_token::T # the token that is used for the X0 state
66-
end
66+
end
67+
68+
#An Ornstein–Uhlenbeck flow process whose reversion mean is the target state X1.
struct OUFlow{T} <: Process
    θ::T       # mean-reversion rate
    v_at_0::T  # variance at t = 0
    v_at_1::T  # variance at t = 1
    dec::T     # decay parameter forwarded to the OU bridge — TODO confirm exact semantics
end

#Convenience constructor: default terminal variance 1e-2 and decay -0.1.
#(Fix: the scraped source had `OUFlow::T, v_at_0::T)` — the `(θ` was lost; restored here.)
OUFlow(θ::T, v_at_0::T) where T = OUFlow(θ, v_at_0, T(1e-2), T(-0.1))

0 commit comments

Comments
 (0)