
Commit eb3e8e6

Adding example of what not to do
1 parent aead5d9 commit eb3e8e6

1 file changed: +105 -0 lines changed

examples/continuous_huber.jl

#=
NOTE: This example demonstrates what can go wrong when you make a seemingly innocuous change to the loss function.
If the loss is not compatible with the bridging process used to construct training pairs, the model will not learn to sample from the target distribution.
The samples will often still wind up somewhere on the "data manifold", so experimenting can be worthwhile, but be aware that you may no longer be learning the target distribution.
=#
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: X̂1 = model(t,Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_time, embed_state, ffs, decode)
    FModel(layers)
end

function (f::FModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
end
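
#The forward pass returns an X̂1 prediction directly: Xt plus a decoded residual
#whose scale shrinks as t → 1 (the 1.05f0 keeps it slightly above zero at t = 1).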

T = Float32
#Distributions for training:
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
#Note: 95% of the samples should be in the left blob
X1draw() = rand() < 0.95 ? randn(T, 2) .* 0.5f0 .+ [-5, -4] : randn(T, 2) .* 0.5f0 .+ [6, -10]
sampleX1(n_samples) = stack([X1draw() for _ in 1:n_samples])
n_samples = 400

P = BrownianMotion(0.15f0)
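#P is the process used both to bridge X0 and X1 into training states Xt, and to step samples during generation.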

#Alternative loss functions:
huber(x) = 2((abs(x) < 1) ? x^2/2 : abs(x) - 1/2)
L2(x) = x^2 #This is the "right" loss for the Brownian motion process
plot(-4:0.01:4, huber, label = "Huber loss", size = (400,400))
plot!(-4:0.01:4, L2, label = "L2")
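
#=
Aside: why the loss choice matters. For a bimodal X1 like the one above, the expected L2
loss is minimized at the mean, while the expected Huber loss is minimized much closer to
the heavier mode. A minimal grid-search check along the first coordinate:
=#
let
    x1s = [X1draw()[1] for _ in 1:10_000]                          #first coordinate of X1 draws
    grid = -6f0:0.01f0:7f0
    l2opt = grid[argmin([sum(L2.(θ .- x1s)) for θ in grid])]       #≈ the sample mean (≈ -4.45)
    hubopt = grid[argmin([sum(huber.(θ .- x1s)) for θ in grid])]   #pulled toward the -5 mode
    println("L2 minimizer: $l2opt; Huber minimizer: $hubopt; mean: $(sum(x1s)/length(x1s))")
end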

for (lossname, lossf) in [("Huber", huber), ("L2", L2)]
    model = FModel(embeddim = 256, layers = 4, spacedim = 2)
    eta = initeta = 0.003
    opt_state = Flux.setup(AdamW(eta = eta), model)

    iters = 10000
    for i in 1:iters
        #Set up a batch of training pairs, and t:
        X0 = ContinuousState(sampleX0(n_samples))
        X1 = ContinuousState(sampleX1(n_samples))
        t = rand(T, n_samples)
        #Construct the bridge:
        Xt = bridge(P, X0, X1, t)
        #Gradient & update:
        l, g = Flux.withgradient(model) do m
            #floss(P, m(t,Xt), X1, scalefloss(P, t)) #Flowfusion.jl's default loss - essentially L2.
            sum(sum(lossf.(m(t,Xt) .- tensor(X1)), dims = 1) .* expand(t .+ 0.05f0, ndims(tensor(X1)))) #Either Huber or L2, scaled with time.
        end
        Flux.update!(opt_state, model, g[1])
        #Linearly decay the learning rate:
        eta = eta - initeta/iters
        Optimisers.adjust!(opt_state, eta)
        (i % 10 == 0) && println("i: $i; Loss: $l")
    end

    #Generate samples by stepping from X0, tracking the trajectory as we go:
    n_inference_samples = 5000
    X0 = ContinuousState(sampleX0(n_inference_samples))
    paths = Tracker()
    samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
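    #paths records t and Xt at each generation step; stack_tracker reads them back for plotting.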

    #Plotting:
    pl = scatter(X0.state[1,:], X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
    tvec = stack_tracker(paths, :t)
    xttraj = stack_tracker(paths, :xt)
    for i in 1:50:1000
        plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = i==1 ? "Trajectory" : :none, alpha = 0.4)
    end

    #Fraction of generated samples landing in the left blob (the target is 0.95):
    ratio = sum(samples.state[1,:] .< 0)/n_inference_samples

    X1true = sampleX1(n_inference_samples)
    scatter!(X1true[1,:], X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)", title = "$(lossname): Ratio=$ratio")
    scatter!(samples.state[1,:], samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.25, label = "X1 (generated)")
    display(pl)
    savefig("blob_$(lossname).svg")
end
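
#=
Expected outcome, per the note at the top: the L2 run should recover the ~95/5 split
between the two blobs, while the Huber run typically still lands samples on the data
manifold but with the wrong proportions, since its minimizer is pulled toward the
conditional mode rather than the conditional mean that the bridging process requires.
=#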