#=
NOTE: This example demonstrates what can go wrong when you make a seemingly innocuous change to the loss function.
If the loss is not compatible with the bridging process used to construct training pairs, the model will not learn to sample from the target distribution.
The samples will often still wind up somewhere on the "data manifold", so it can be worth experimenting with, but be aware that you may no longer be learning the target distribution.
=#
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: X̂1 = model(t,Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_time, embed_state, ffs, decode)
    FModel(layers)
end

function (f::FModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
end
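
#= Aside (not in the original script): a quick, minimal shape check. The model maps (t, Xt)
to a prediction of X̂1 with the same 2×batch layout as the state tensor, matching the
X̂1 = model(t, Xt) convention noted above. Illustrative only. =#
let m = FModel(embeddim = 32, layers = 2, spacedim = 2)
    Xt_demo = ContinuousState(rand(Float32, 2, 8)) #A dummy batch of 8 two-dimensional states
    @assert size(m(rand(Float32, 8), Xt_demo)) == (2, 8)
end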

T = Float32
#Distributions for training:
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
#Note: 95% of the samples should be in the left blob
X1draw() = rand() < 0.95 ? randn(T, 2) .* 0.5f0 .+ [-5, -4] : randn(T, 2) .* 0.5f0 .+ [6, -10]
sampleX1(n_samples) = stack([X1draw() for _ in 1:n_samples])
n_samples = 400

P = BrownianMotion(0.15f0)

#Alternative loss functions:
huber(x) = 2((abs(x) < 1) ? x^2/2 : abs(x) - 1/2)
L2(x) = x^2 #This is the "right" loss for the Brownian motion process
plot(-4:0.01:4, huber, label = "Huber loss", size = (400,400))
plot!(-4:0.01:4, L2, label = "L2")
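
#= Aside (not part of the original example): a minimal 1-D sketch of why the pointwise loss
matters here. The constant that minimizes the summed L2 loss over a sample is its mean, while
the Huber minimizer sits between the mean and the median, pulled toward the dominant mode.
Roughly speaking, the Brownian-motion stepping treats the model output as an estimate of the
conditional mean of X1, so training with Huber makes the model converge to something else,
and the generated mixture weights can drift from the target. The 95/5 weights below mirror
the two blobs defined above; this block is illustrative only.
=#
let
    xs = vcat(fill(-5.0, 95), fill(6.0, 5))                          #Toy 1-D mixture with the blob weights above
    l2_minimizer = sum(xs) / length(xs)                              #The L2-optimal constant is the mean
    cs = -6.0:0.01:7.0
    huber_minimizer = cs[argmin([sum(huber.(xs .- c)) for c in cs])] #Grid search for the Huber-optimal constant (lands near the dominant mode)
    println("L2 minimizer (mean): ", l2_minimizer, ";  Huber minimizer: ", huber_minimizer)
end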

for (lossname, lossf) in [("Huber", huber), ("L2", L2)]
    model = FModel(embeddim = 256, layers = 4, spacedim = 2)
    eta = initeta = 0.003
    opt_state = Flux.setup(AdamW(eta = eta), model)

    iters = 10000
    for i in 1:iters
        #Set up a batch of training pairs, and t:
        X0 = ContinuousState(sampleX0(n_samples))
        X1 = ContinuousState(sampleX1(n_samples))
        t = rand(T, n_samples)
        #Construct the bridge:
        Xt = bridge(P, X0, X1, t)
        #Gradient & update:
        l,g = Flux.withgradient(model) do m
            #floss(P, m(t,Xt), X1, scalefloss(P, t)) #Flowfusion.jl default loss - basically L2.
            sum(sum(lossf.(m(t,Xt) .- tensor(X1)), dims = 1) .* expand(t .+ 0.05f0, ndims(tensor(X1)))) #Either Huber or L2, scaled with time (later t weighted more heavily).
        end
        Flux.update!(opt_state, model, g[1])
        #Linearly decay the learning rate over training:
        eta = eta - initeta/iters
        Optimisers.adjust!(opt_state, eta)
        (i % 10 == 0) && println("i: $i; Loss: $l")
    end

    #Generate samples by stepping from X0, tracking the trajectories for plotting:
    n_inference_samples = 5000
    X0 = ContinuousState(sampleX0(n_inference_samples))
    paths = Tracker()
    samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)

    #Plotting:
    pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
    tvec = stack_tracker(paths, :t)
    xttraj = stack_tracker(paths, :xt)
    for i in 1:50:1000
        plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = i==1 ? "Trajectory" : :none, alpha = 0.4)
    end

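    #Fraction of generated samples landing in the left (x < 0) blob; this should be close to 0.95 if the target mixture weights were learned.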
    ratio = sum(samples.state[1,:] .< 0)/n_inference_samples

    X1true = sampleX1(n_inference_samples)
    scatter!(X1true[1,:],X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)", title = "$(lossname): Ratio=$ratio")
    scatter!(samples.state[1,:],samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.25, label = "X1 (generated)")
    display(pl)
    savefig("blob_$(lossname).svg")
end