@@ -3,7 +3,7 @@ Pkg.activate(".")
33using Revise
44Pkg. develop (path= " ../../ForwardBackward/" )
55Pkg. develop (path= " ../" )
6- using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
6+ using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots
77
88# Set up a Flux model: X̂1 = model(t,Xt)
99struct FModel{A}
@@ -34,47 +34,51 @@ model = FModel(embeddim = 256, layers = 3, spacedim = 2)
3434
3535# Distributions for training:
3636T = Float32
37- sampleX1 (n_samples) = Flowfusion. random_literal_cat (n_samples, sigma = T (0.05 ))
3837sampleX0 (n_samples) = rand (T, 2 , n_samples) .+ 2
39- n_samples = 200
38+ sampleX1 (n_samples) = Flowfusion. random_literal_cat (n_samples, sigma = T (0.05 ))
39+ n_samples = 400
4040
4141# The process:
42- P = BrownianMotion (0.1f0 )
42+ P = BrownianMotion (0.15f0 )
4343# P = Deterministic()
4444
4545# Optimizer:
46- eta = 0.01
47- opt_state = Flux. setup (AdamW (eta = eta, lambda = 0.001 ), model)
46+ eta = 0.001
47+ opt_state = Flux. setup (AdamW (eta = eta), model)
4848
49- iters = 5000
49+ iters = 4000
5050for i in 1 : iters
5151 # Set up a batch of training pairs, and t:
52- X1 = ContinuousState (sampleX1 (n_samples))
5352 X0 = ContinuousState (sampleX0 (n_samples))
53+ X1 = ContinuousState (sampleX1 (n_samples))
5454 t = rand (T, n_samples)
5555 # Construct the bridge:
5656 Xt = bridge (P, X0, X1, t)
57- # Gradient:
57+ # Gradient & update :
5858 l,g = Flux. withgradient (model) do m
5959 floss (P, m (t,Xt), X1, scalefloss (P, t))
6060 end
61- # Update:
6261 Flux. update! (opt_state, model, g[1 ])
63- # Logging, and lr cooldown:
64- if i % 10 == 0
65- if i > iters - 2000
66- eta *= 0.975
67- Optimisers. adjust! (opt_state, eta)
68- end
69- println (" i: $i ; Loss: $l ; eta: $eta " )
70- end
62+ (i % 10 == 0 ) && println (" i: $i ; Loss: $l " )
7163end
7264
65+ # Generate samples by stepping from X0
66+ n_inference_samples = 5000
67+ X0 = ContinuousState (sampleX0 (n_inference_samples))
68+ samples = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 )
69+
70+ # Plotting
71+ pl = scatter (X0. state[1 ,:],X0. state[2 ,:], msw = 0 , ms = 1 , color = " blue" , alpha = 0.5 , size = (400 ,400 ), legend = :topleft , label = " X0" )
72+ X1true = sampleX1 (n_inference_samples)
73+ scatter! (X1true[1 ,:],X1true[2 ,:], msw = 0 , ms = 1 , color = " orange" , alpha = 0.5 , label = " X1 (true)" )
74+ scatter! (samples. state[1 ,:],samples. state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
75+
76+
7377# Generate samples by stepping from X0
7478n_inference_samples = 5000
7579X0 = ContinuousState (sampleX0 (n_inference_samples))
7680paths = Tracker ()
77- samp = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 , tracker = paths)
81+ samples = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 , tracker = paths)
7882
7983# Plotting:
8084pl = scatter (X0. state[1 ,:],X0. state[2 ,:], msw = 0 , ms = 1 , color = " blue" , alpha = 0.5 , size = (400 ,400 ), legend = :topleft , label = " X0" )
@@ -85,6 +89,6 @@ for i in 1:50:1000
8589end
8690X1true = sampleX1 (n_inference_samples)
8791scatter! (X1true[1 ,:],X1true[2 ,:], msw = 0 , ms = 1 , color = " orange" , alpha = 0.5 , label = " X1 (true)" )
88- scatter! (samp . state[1 ,:],samp . state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
92+ scatter! (samples . state[1 ,:],samples . state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
8993display (pl)
9094savefig (" continuous_cat_$P .svg" )
0 commit comments