|
| 1 | +using Pkg |
| 2 | +Pkg.activate(".") |
| 3 | +using Revise |
| 4 | +Pkg.develop(path="../../ForwardBackward/") |
| 5 | +Pkg.develop(path="../") |
| 6 | +using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds |
| 7 | + |
| 8 | +#Set up a Flux model: ξhat = model(t,Xt) |
| 9 | +struct PSModel{A} |
| 10 | + layers::A |
| 11 | +end |
| 12 | +Flux.@layer PSModel |
| 13 | +function PSModel(; embeddim = 128, l = 2, K = 33, layers = 3) |
| 14 | + embed_time = Chain(RandomFourierFeatures(1 => embeddim, 2.0f0), Dense(embeddim => embeddim, swish)) |
| 15 | + embed_char = Dense(K => embeddim, bias = false) |
| 16 | + mix = Dense(l*embeddim => embeddim, swish) |
| 17 | + ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers] |
| 18 | + decode = Dense(embeddim => l*(K-1)) #Tangent coord is one less than the number of categories |
| 19 | + layers = (; embed_time, embed_char, mix, ffs, decode) |
| 20 | + PSModel(layers) |
| 21 | +end |
| 22 | + |
| 23 | +function (f::PSModel)(t, Xt) |
| 24 | + l = f.layers |
| 25 | + tXt = tensor(Xt) |
| 26 | + len = size(tXt)[end] |
| 27 | + tv = zero(similar(Float32.(tXt), 1, len)) .+ expand(t, 2) |
| 28 | + x = l.mix(reshape(l.embed_char(tXt), :, len)) .+ l.embed_time(tv) |
| 29 | + for ff in l.ffs |
| 30 | + x = x .+ ff(x) |
| 31 | + end |
| 32 | + return reshape(l.decode(x), :, 2, len) .* (1.05f0 .- expand(t, 3)) |
| 33 | +end |
| 34 | + |
| 35 | +model = PSModel(embeddim = 256, l = 2, K = 33, layers = 3) |
| 36 | + |
| 37 | +sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples) |
| 38 | +sampleX0(n_samples) = rand(25:32, 2, n_samples) |
| 39 | + |
| 40 | +T = Float32 |
| 41 | +n_samples = 200 |
| 42 | + |
| 43 | +M = ProbabilitySimplex(32) |
| 44 | +P = ManifoldProcess(0.5f0) |
| 45 | + |
| 46 | +eta = 0.01 |
| 47 | +opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.01), model) |
| 48 | + |
| 49 | +iters = 5000 |
| 50 | +for i in 1:iters |
| 51 | + #Set up a batch of training pairs, and t |
| 52 | + X0 = ManifoldState(T, M, sampleX0(n_samples)) #Note T when constructing a ManifoldState from discrete values |
| 53 | + X1 = ManifoldState(T, M, sampleX1(n_samples)) |
| 54 | + t = rand(T, n_samples) |
| 55 | + #Construct the bridge: |
| 56 | + Xt = bridge(P, X0, X1, t) |
| 57 | + #Get the Xt->X1 tangent coordinates: |
| 58 | + ξ = Flowfusion.tangent_coordinates(Xt, X1) |
| 59 | + #Gradient: |
| 60 | + l,g = Flux.withgradient(model) do m |
| 61 | + tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t)) |
| 62 | + end |
| 63 | + #Update: |
| 64 | + Flux.update!(opt_state, model, g[1]) |
| 65 | + #Log and adjust learning rate: |
| 66 | + if i % 10 == 0 |
| 67 | + if i > iters - 2000 |
| 68 | + eta *= 0.975 |
| 69 | + Optimisers.adjust!(opt_state, eta) |
| 70 | + end |
| 71 | + println("i: $i; Loss: $l; eta: $eta") |
| 72 | + end |
| 73 | +end |
| 74 | + |
| 75 | +#Generate samples by stepping from X0 |
| 76 | +n_inference_samples = 5000 |
| 77 | +X0 = ManifoldState(T, M, sampleX0(n_inference_samples)); |
| 78 | +paths = Tracker() |
| 79 | +X1pred = (t,Xt) -> apply_tangent_coordinates(Xt, model(t,tensor(Xt))) |
| 80 | +samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths) |
| 81 | + |
| 82 | +#Plot the X0 and generated X1: |
| 83 | +X0oc = Flowfusion.onecold(tensor(X0)) |
| 84 | +sampoc = Flowfusion.onecold(tensor(samp)) |
| 85 | +pl = scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (sampled)", titlefontsize = 9) |
| 86 | +scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0") |
| 87 | +scatter!(sampoc[1,:],sampoc[2,:], msw = 0, color = "green", alpha = 0.02, label = :none) |
| 88 | +scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "gen X1") |
| 89 | + |
| 90 | +#...compared to the true X1: |
| 91 | +trueX1 = sampleX1(n_inference_samples) |
| 92 | +pl = scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34)) |
| 93 | +scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0") |
| 94 | +scatter!(trueX1[1,:],trueX1[2,:], msw = 0, color = "green", alpha = 0.02, label = :none) |
| 95 | +scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "true X1") |
| 96 | + |
| 97 | +#Plot a random individual trajectoty: |
| 98 | +tvec = stack_tracker(paths, :t) |
| 99 | +xttraj = stack_tracker(paths, :xt) |
| 100 | +plot(tvec, xttraj[:,1,rand(1:n_samples),:]', legend = :none) |
| 101 | + |
| 102 | +#Animate trajectories as the product of evolving marginals (needs the above code to have run): |
| 103 | +x1s = [j for i in 1:33, j in 1:33] |
| 104 | +x2s = [i for i in 1:33, j in 1:33] |
| 105 | +i = rand(1:n_samples) |
| 106 | +gridtraj_vec = [reshape(xttraj[:,1,s,:], 1, 33, :) .* reshape(xttraj[:,2,s,:], 33, 1, :) for s in 1:30] |
| 107 | +anim = @animate for i ∈ vcat(zeros(Int, 30), ones(Int, 10), collect(1:500), ones(Int, 10).*500, ones(Int, 30).*501) |
| 108 | + if i == 0 |
| 109 | + scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (true)", titlefontsize = 9) |
| 110 | + scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0") |
| 111 | + scatter!(trueX1[1,:],trueX1[2,:], msw = 0, color = "green", alpha = 0.02, label = :none) |
| 112 | + scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "true X1") |
| 113 | + end |
| 114 | + if 1 <= i <= 500 |
| 115 | + plot(; colorbar = :none, legend = :none, size = (400,400), title = "$(length(gridtraj_vec)) probability simplex trajectories, t = $(round(tvec[i], digits = 2))", titlefontsize = 9, xlim = (0,34), ylim = (0, 34)) |
| 116 | + for g in gridtraj_vec |
| 117 | + scatter!(x1s, x2s, msw = 0, ms = sqrt.(150 .* g[:,:,i]), colorbar = :none, legend = :none, size = (400,400), alpha = 0.4) |
| 118 | + end |
| 119 | + end |
| 120 | + if i > 500 |
| 121 | + scatter(X0oc[1,:],X0oc[2,:], msw = 0, color = "blue", label = :none, alpha = 0.02, size = (400,400), xlim = (0,34), ylim = (0, 34), title = "X0, X1 (sampled)", titlefontsize = 9) |
| 122 | + scatter!([-10], [-10], msw = 0, color = "blue", alpha = 0.3, label = "X0") |
| 123 | + scatter!(sampoc[1,:],sampoc[2,:], msw = 0, color = "green", alpha = 0.02, label = :none) |
| 124 | + scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "gen X1") |
| 125 | + end |
| 126 | +end |
| 127 | +gif(anim, "probsimplex_$(P).mp4", fps = 30) |
0 commit comments