@@ -3,7 +3,7 @@ Pkg.activate(".")
3
3
using Revise
4
4
Pkg. develop (path= " ../../ForwardBackward/" )
5
5
Pkg. develop (path= " ../" )
6
- using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
6
+ using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots
7
7
8
8
# Set up a Flux model: X̂1 = model(t,Xt)
9
9
struct FModel{A}
@@ -34,47 +34,51 @@ model = FModel(embeddim = 256, layers = 3, spacedim = 2)
34
34
35
35
# Distributions for training:
36
36
T = Float32
37
- sampleX1 (n_samples) = Flowfusion. random_literal_cat (n_samples, sigma = T (0.05 ))
38
37
sampleX0 (n_samples) = rand (T, 2 , n_samples) .+ 2
39
- n_samples = 200
38
+ sampleX1 (n_samples) = Flowfusion. random_literal_cat (n_samples, sigma = T (0.05 ))
39
+ n_samples = 400
40
40
41
41
# The process:
42
- P = BrownianMotion (0.1f0 )
42
+ P = BrownianMotion (0.15f0 )
43
43
# P = Deterministic()
44
44
45
45
# Optimizer:
46
- eta = 0.01
47
- opt_state = Flux. setup (AdamW (eta = eta, lambda = 0.001 ), model)
46
+ eta = 0.001
47
+ opt_state = Flux. setup (AdamW (eta = eta), model)
48
48
49
- iters = 5000
49
+ iters = 4000
50
50
for i in 1 : iters
51
51
# Set up a batch of training pairs, and t:
52
- X1 = ContinuousState (sampleX1 (n_samples))
53
52
X0 = ContinuousState (sampleX0 (n_samples))
53
+ X1 = ContinuousState (sampleX1 (n_samples))
54
54
t = rand (T, n_samples)
55
55
# Construct the bridge:
56
56
Xt = bridge (P, X0, X1, t)
57
- # Gradient:
57
+ # Gradient & update :
58
58
l,g = Flux. withgradient (model) do m
59
59
floss (P, m (t,Xt), X1, scalefloss (P, t))
60
60
end
61
- # Update:
62
61
Flux. update! (opt_state, model, g[1 ])
63
- # Logging, and lr cooldown:
64
- if i % 10 == 0
65
- if i > iters - 2000
66
- eta *= 0.975
67
- Optimisers. adjust! (opt_state, eta)
68
- end
69
- println (" i: $i ; Loss: $l ; eta: $eta " )
70
- end
62
+ (i % 10 == 0 ) && println (" i: $i ; Loss: $l " )
71
63
end
72
64
65
+ # Generate samples by stepping from X0
66
+ n_inference_samples = 5000
67
+ X0 = ContinuousState (sampleX0 (n_inference_samples))
68
+ samples = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 )
69
+
70
+ # Plotting
71
+ pl = scatter (X0. state[1 ,:],X0. state[2 ,:], msw = 0 , ms = 1 , color = " blue" , alpha = 0.5 , size = (400 ,400 ), legend = :topleft , label = " X0" )
72
+ X1true = sampleX1 (n_inference_samples)
73
+ scatter! (X1true[1 ,:],X1true[2 ,:], msw = 0 , ms = 1 , color = " orange" , alpha = 0.5 , label = " X1 (true)" )
74
+ scatter! (samples. state[1 ,:],samples. state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
75
+
76
+
73
77
# Generate samples by stepping from X0
74
78
n_inference_samples = 5000
75
79
X0 = ContinuousState (sampleX0 (n_inference_samples))
76
80
paths = Tracker ()
77
- samp = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 , tracker = paths)
81
+ samples = gen (P, X0, model, 0f0 : 0.005f0 : 1f0 , tracker = paths)
78
82
79
83
# Plotting:
80
84
pl = scatter (X0. state[1 ,:],X0. state[2 ,:], msw = 0 , ms = 1 , color = " blue" , alpha = 0.5 , size = (400 ,400 ), legend = :topleft , label = " X0" )
@@ -85,6 +89,6 @@ for i in 1:50:1000
85
89
end
86
90
X1true = sampleX1 (n_inference_samples)
87
91
scatter! (X1true[1 ,:],X1true[2 ,:], msw = 0 , ms = 1 , color = " orange" , alpha = 0.5 , label = " X1 (true)" )
88
- scatter! (samp . state[1 ,:],samp . state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
92
+ scatter! (samples . state[1 ,:],samples . state[2 ,:], msw = 0 , ms = 1 , color = " green" , alpha = 0.5 , label = " X1 (generated)" )
89
93
display (pl)
90
94
savefig (" continuous_cat_$P .svg" )
0 commit comments