
Commit 7ac62b8

Examples
1 parent 1a0a741 commit 7ac62b8

File tree: 6 files changed (+247, -30143 lines)

examples/continuous.jl

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ for i in 1:iters
     Xt = bridge(P, X0, X1, t)
     #Gradient:
     l,g = Flux.withgradient(model) do m
-        floss(P, m(t,Xt), X1, scalefloss(P, t, 2))
+        floss(P, m(t,Xt), X1, scalefloss(P, t))
     end
     #Update:
     Flux.update!(opt_state, model, g[1])

examples/continuous_cat_BrownianMotion{Float32}(0.0f0, 0.1f0).svg

Lines changed: 0 additions & 15072 deletions
This file was deleted.

examples/continuous_cat_Deterministic().svg

Lines changed: 0 additions & 15070 deletions
This file was deleted.

examples/probabilitysimplex.jl

Lines changed: 127 additions & 0 deletions
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds

#Set up a Flux model: ξhat = model(t,Xt)
struct PSModel{A}
    layers::A
end
Flux.@layer PSModel
function PSModel(; embeddim = 128, l = 2, K = 33, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 2.0f0), Dense(embeddim => embeddim, swish))
    embed_char = Dense(K => embeddim, bias = false)
    mix = Dense(l*embeddim => embeddim, swish)
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => l*(K-1)) #Tangent coord is one less than the number of categories
    layers = (; embed_time, embed_char, mix, ffs, decode)
    PSModel(layers)
end

function (f::PSModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    len = size(tXt)[end]
    tv = zero(similar(Float32.(tXt), 1, len)) .+ expand(t, 2)
    x = l.mix(reshape(l.embed_char(tXt), :, len)) .+ l.embed_time(tv)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    return reshape(l.decode(x), :, 2, len) .* (1.05f0 .- expand(t, 3))
end

model = PSModel(embeddim = 256, l = 2, K = 33, layers = 3)
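For orientation, a quick shape check of the model defined above. This is a sketch, not part of the commit: it assumes (as the training loop below does) that the model can be called directly on the state tensor, and that expand(t, n) reshapes t for broadcasting; the *_demo names are only used here. With K = 33 categories at l = 2 positions, a (33, 2, batch) simplex state maps to a (32, 2, batch) block of tangent coordinates:

    t_demo   = rand(Float32, 16)                             # one time value per batch element
    tXt_demo = softmax(randn(Float32, 33, 2, 16), dims = 1)  # stand-in for tensor(Xt): points on the simplex
    size(model(t_demo, tXt_demo))                            # (32, 2, 16): K-1 tangent coordinates per position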
sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples)
sampleX0(n_samples) = rand(25:32, 2, n_samples)

T = Float32
n_samples = 200

M = ProbabilitySimplex(32)
P = ManifoldProcess(0.5f0)

eta = 0.01
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.01), model)

iters = 5000
for i in 1:iters
    #Set up a batch of training pairs, and t
    X0 = ManifoldState(T, M, sampleX0(n_samples)) #Note T when constructing a ManifoldState from discrete values
    X1 = ManifoldState(T, M, sampleX1(n_samples))
    t = rand(T, n_samples)
    #Construct the bridge:
    Xt = bridge(P, X0, X1, t)
    #Get the Xt->X1 tangent coordinates:
    ξ = Flowfusion.tangent_coordinates(Xt, X1)
    #Gradient:
    l,g = Flux.withgradient(model) do m
        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
    end
    #Update:
    Flux.update!(opt_state, model, g[1])
    #Log and adjust learning rate:
    if i % 10 == 0
        if i > iters - 2000
            eta *= 0.975
            Optimisers.adjust!(opt_state, eta)
        end
        println("i: $i; Loss: $l; eta: $eta")
    end
end
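Before sampling, a small round-trip check makes the parameterization concrete. This sketch is not part of the commit; it assumes apply_tangent_coordinates inverts tangent_coordinates at Xt, which is exactly how the X1pred wrapper below uses the model's output (the *chk names are only used here):

    X0chk = ManifoldState(T, M, sampleX0(8))
    X1chk = ManifoldState(T, M, sampleX1(8))
    Xtchk = bridge(P, X0chk, X1chk, rand(T, 8))
    ξchk  = Flowfusion.tangent_coordinates(Xtchk, X1chk)
    maximum(abs, tensor(apply_tangent_coordinates(Xtchk, ξchk)) .- tensor(X1chk))  # ≈ 0 if the round trip is exact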
#Generate samples by stepping from X0
n_inference_samples = 5000
X0 = ManifoldState(T, M, sampleX0(n_inference_samples));
paths = Tracker()
X1pred = (t,Xt) -> apply_tangent_coordinates(Xt, model(t,tensor(Xt)))
samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)

#Plot the X0 and generated X1:
X0oc = Flowfusion.onecold(tensor(X0))
sampoc = Flowfusion.onecold(tensor(samp))
pl = scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (sampled)", titlefontsize = 9)
scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0")
scatter!(sampoc[1,:],sampoc[2,:], msw = 0, color = "green", alpha = 0.02, label = :none)
scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "gen X1")

#...compared to the true X1:
trueX1 = sampleX1(n_inference_samples)
pl = scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34))
scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0")
scatter!(trueX1[1,:],trueX1[2,:], msw = 0, color = "green", alpha = 0.02, label = :none)
scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "true X1")

#Plot a random individual trajectory:
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
plot(tvec, xttraj[:,1,rand(1:n_samples),:]', legend = :none)

#Animate trajectories as the product of evolving marginals (needs the above code to have run):
x1s = [j for i in 1:33, j in 1:33]
x2s = [i for i in 1:33, j in 1:33]
i = rand(1:n_samples)
gridtraj_vec = [reshape(xttraj[:,1,s,:], 1, 33, :) .* reshape(xttraj[:,2,s,:], 33, 1, :) for s in 1:30]
anim = @animate for i in vcat(zeros(Int, 30), ones(Int, 10), collect(1:500), ones(Int, 10).*500, ones(Int, 30).*501)
    if i == 0
        scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (true)", titlefontsize = 9)
        scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0")
        scatter!(trueX1[1,:],trueX1[2,:], msw = 0, color = "green", alpha = 0.02, label = :none)
        scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "true X1")
    end
    if 1 <= i <= 500
        plot(; colorbar = :none, legend = :none, size = (400,400), title = "$(length(gridtraj_vec)) probability simplex trajectories, t = $(round(tvec[i], digits = 2))", titlefontsize = 9, xlim = (0,34), ylim = (0, 34))
        for g in gridtraj_vec
            scatter!(x1s, x2s, msw = 0, ms = sqrt.(150 .* g[:,:,i]), colorbar = :none, legend = :none, size = (400,400), alpha = 0.4)
        end
    end
    if i > 500
        scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (sampled)", titlefontsize = 9)
        scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0")
        scatter!(sampoc[1,:],sampoc[2,:], msw = 0, color = "green", alpha = 0.02, label = :none)
        scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "gen X1")
    end
end
gif(anim, "probsimplex_$(P).mp4", fps = 30)

examples/torus.jl

Lines changed: 111 additions & 0 deletions
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: ξhat = model(t,Xt)
struct TModel{A}
    layers::A
end
Flux.@layer TModel
function TModel(; embeddim = 64, spacedim = 2, layers = 5)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 2f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(4 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_time, embed_state, ffs, decode)
    TModel(layers)
end

function (f::TModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    enc = vcat(sin.(tXt), cos.(tXt))
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(enc)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    return (l.decode(x) .* (1.05f0 .- tv))
end

model = TModel(embeddim = 256, layers = 3, spacedim = 2)
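Because the state enters only through sin/cos features, the prediction is unchanged when any angle is shifted by 2π, which is what makes this encoding natural on the torus. A quick check, not part of the commit (it assumes, as the training loop below does, that the model can be called directly on the matrix of angles; θ_demo and t_demo are only used here):

    θ_demo = rand(Float32, 2, 4) .* (2f0 * Float32(π))
    t_demo = rand(Float32, 4)
    maximum(abs, model(t_demo, θ_demo) .- model(t_demo, θ_demo .+ 2f0 * Float32(π)))  # ≈ 0, up to Float32 rounding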
T = Float32
sampleX0(n_samples) = rand(T, 2, n_samples) .+ [2.1f0, 1]
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))[[2,1],:] .* 0.4f0 .- [-0.1f0, 1.3f0]
n_samples = 500

M = Torus(2)
#P = ManifoldProcess(0.2f0)
P = Deterministic()

eta = 0.01
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.00001), model)

iters = 8000
for i in 1:iters
    #Set up a batch of training pairs, and t
    X1 = ManifoldState(M, eachcol(sampleX1(n_samples))) #Note: eachcol
    X0 = ManifoldState(M, eachcol(sampleX0(n_samples)))
    t = rand(T, n_samples)
    #Construct the bridge:
    Xt = bridge(P, X0, X1, t)
    #Compute the tangent coordinates:
    ξ = Flowfusion.tangent_coordinates(Xt, X1)
    #Gradient
    l,g = Flux.withgradient(model) do m
        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
    end
    #Update
    Flux.update!(opt_state, model, g[1])
    #Logging, and lr cooldown:
    if i % 10 == 0
        if i > iters - 3000
            eta *= 0.975
            Optimisers.adjust!(opt_state, eta)
        end
        println("i: $i; Loss: $l; eta: $eta")
    end
end

#Generate samples by stepping from X0:
n_inference_samples = 2000
X0 = ManifoldState(M, eachcol(sampleX0(n_inference_samples)))
paths = Tracker()
#We wrap the model, because it predicts tangent coordinates, not the actual state:
X1pred = (t,Xt) -> apply_tangent_coordinates(Xt, model(t,tensor(Xt)))
samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)

#Plot the torus, with samples, and trajectories:
#Project Torus(2) into 3D (just for plotting)
function tor(p; R::Real=2, r::Real=0.5)
    u,v = p[1], p[2]
    x = (R + r*cos(u)) * cos(v)
    y = (R + r*cos(u)) * sin(v)
    z = r * sin(u)
    return [x, y, z]
end
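As a reference for the projection above (u is the tube angle, v the angle around the central axis; R = 2 and r = 0.5 by default), two spot checks that follow directly from the formula and are not part of the commit:

    tor([0.0, 0.0])  # [2.5, 0.0, 0.0]: outer equator
    tor([π, 0.0])    # [1.5, 0.0, ~6e-17]: inner equator (z is only nonzero from Float64 rounding of sin(π))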
R = 2
r = 0.5
u = range(0, 2π; length=100)
v = range(0, 2π; length=100)
pl = plot([(R + r*cos(θ))*cos(φ) for θ in u, φ in v], [(R + r*cos(θ))*sin(φ) for θ in u, φ in v], [r*sin(θ) for θ in u, φ in v],
    color = "grey", alpha = 0.3, label = :none, camera = (30,30))
torX0 = stack(tor.(eachcol(tensor(X0))))
torSamp = stack(tor.(eachcol(tensor(samp))))
scatter!(torX0[1,:], torX0[2,:], torX0[3,:], label = "X0", msw = 0, ms = 1, color = "blue", alpha = 0.3)
torTarget = stack(tor.(eachcol(sampleX1(n_inference_samples))))
scatter!(torTarget[1,:], torTarget[2,:], torTarget[3,:], label = "X1 (true)", msw = 0, ms = 1, color = "orange", alpha = 0.2)
scatter!(torSamp[1,:], torSamp[2,:], torSamp[3,:], label = "X1 (generated)", msw = 0, ms = 1, color = "green", alpha = 0.3)
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
for i in 1:50:1000
    tr = stack(tor.(eachcol(xttraj[:,i,:])))
    plot!(tr[1,:], tr[2,:], tr[3,:], color = "red", alpha = 0.3, linewidth = 0.5, label = i == 1 ? "Trajectory" : :none)
end
display(pl)
savefig("torus_$P.svg")

src/Flowfusion.jl

Lines changed: 8 additions & 0 deletions
@@ -27,5 +27,13 @@ cat_shape(t) = [-(721*sin(t))/4+196/3*sin(2*t)-86/3*sin(3*t)-131/2*sin(4*t)+477/
 
 random_literal_cat(dims...; sigma = 0.05f0) = typeof(sigma).(stack([cat_shape(rand()*2pi)/200 for _ in zeros(dims...)]) .+ randn(2, dims...) * sigma)
 
+function discretize(x, d, lo, hi)
+    for (i, v) in enumerate(range(lo, hi, length = d - 1))
+        x < v && return i
+    end
+    d
+end
+
+random_discrete_cat(dims...; d = 32, lo = -2.5, hi = 2.5) = discretize.(random_literal_cat(dims...), (d,), (lo,), (hi,))
 
 end
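A quick illustration of the new helper, not part of the commit: discretize bins a real value into d buckets whose boundaries are d-1 evenly spaced points from lo to hi, sending anything below lo to bin 1 and anything at or above hi to bin d; random_discrete_cat applies it elementwise to the continuous cat shape (qualify the names with Flowfusion. when calling from outside the module, as the probability simplex example does):

    Flowfusion.discretize(-3.0, 32, -2.5, 2.5)  # 1  (below the first boundary)
    Flowfusion.discretize(0.37, 32, -2.5, 2.5)  # 19 (boundaries are spaced 5/30 apart)
    Flowfusion.discretize(9.9, 32, -2.5, 2.5)   # 32 (past the last boundary)
    Flowfusion.random_discrete_cat(100)         # 2×100 Matrix{Int} with entries in 1:32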
