Commit 48a6005 (parent: aad2095)

Fixing broken tests due to mask meaning change
2 files changed: +32, -28 lines

examples/continuous.jl (24 additions, 20 deletions)

@@ -3,7 +3,7 @@ Pkg.activate(".")
 using Revise
 Pkg.develop(path="../../ForwardBackward/")
 Pkg.develop(path="../")
-using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
+using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots
 
 #Set up a Flux model: X̂1 = model(t,Xt)
 struct FModel{A}
@@ -34,47 +34,51 @@ model = FModel(embeddim = 256, layers = 3, spacedim = 2)
 
 #Distributions for training:
 T = Float32
-sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
 sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
-n_samples = 200
+sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
+n_samples = 400
 
 #The process:
-P = BrownianMotion(0.1f0)
+P = BrownianMotion(0.15f0)
 #P = Deterministic()
 
 #Optimizer:
-eta = 0.01
-opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)
+eta = 0.001
+opt_state = Flux.setup(AdamW(eta = eta), model)
 
-iters = 5000
+iters = 4000
 for i in 1:iters
 #Set up a batch of training pairs, and t:
-X1 = ContinuousState(sampleX1(n_samples))
 X0 = ContinuousState(sampleX0(n_samples))
+X1 = ContinuousState(sampleX1(n_samples))
 t = rand(T, n_samples)
 #Construct the bridge:
 Xt = bridge(P, X0, X1, t)
-#Gradient:
+#Gradient & update:
 l,g = Flux.withgradient(model) do m
 floss(P, m(t,Xt), X1, scalefloss(P, t))
 end
-#Update:
 Flux.update!(opt_state, model, g[1])
-#Logging, and lr cooldown:
-if i % 10 == 0
-if i > iters - 2000
-eta *= 0.975
-Optimisers.adjust!(opt_state, eta)
-end
-println("i: $i; Loss: $l; eta: $eta")
-end
+(i % 10 == 0) && println("i: $i; Loss: $l")
 end
 
+#Generate samples by stepping from X0
+n_inference_samples = 5000
+X0 = ContinuousState(sampleX0(n_inference_samples))
+samples = gen(P, X0, model, 0f0:0.005f0:1f0)
+
+#Plotting
+pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
+X1true = sampleX1(n_inference_samples)
+scatter!(X1true[1,:],X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
+scatter!(samples.state[1,:],samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
+
+
 #Generate samples by stepping from X0
 n_inference_samples = 5000
 X0 = ContinuousState(sampleX0(n_inference_samples))
 paths = Tracker()
-samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
+samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
 
 #Plotting:
 pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
@@ -85,6 +89,6 @@ for i in 1:50:1000
 end
 X1true = sampleX1(n_inference_samples)
 scatter!(X1true[1,:],X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
-scatter!(samp.state[1,:],samp.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
+scatter!(samples.state[1,:],samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
 display(pl)
 savefig("continuous_cat_$P.svg")
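
One side effect of the rewritten loop: the old logging block also carried a learning-rate cooldown, which this commit drops. If you want that behaviour back, a minimal sketch reconstructed from the deleted lines (the 0.975 decay factor and the last-2000-iterations window are just the old example's values, not anything Flowfusion requires):

#Replaces the plain println inside the training loop, after Flux.update!:
if i % 10 == 0
    if i > iters - 2000
        eta *= 0.975
        Optimisers.adjust!(opt_state, eta) #Optimisers.jl: set the new learning rate in-place
    end
    println("i: $i; Loss: $l; eta: $eta")
end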

test/runtests.jl (8 additions, 8 deletions)

@@ -16,7 +16,7 @@ using ForwardBackward
 XDL() = CategoricalLikelihood(rand(5, siz...))
 XGL() = GaussianLikelihood(randn(5, siz...), randn(5, siz...), zeros(siz...))
 
-for f in [XC, XD, XT, XR, XDL, XGL]
+for f in [XC, XD, XT, XR, XDL, XGL]
 Xa = f()
 Xb = f()
 Xc = Flowfusion.mask(Xa, Xb)
@@ -30,15 +30,15 @@ using ForwardBackward
 
 @test typeof(Xc) == typeof(XM) #If you mask a regular State with a MaskedState, the result is a MaskedState.
 d = (tensor(Xb) .- tensor(Xc))
-@test isapprox(sum(d .* expand(m, ndims(d))),0)
+@test isapprox(sum(d .* expand(.!m, ndims(d))),0)
 
 m = rand(Bool, siz...)
 Xa = MaskedState(f(), m, m)
 Xb = MaskedState(f(), m, m)
 Xc = Flowfusion.mask(Xa, Xb)
 @test typeof(Xc) == typeof(Xa)
 d = (tensor(Xb) .- tensor(Xc))
-@test isapprox(sum(d .* expand(m, ndims(d))),0)
+@test isapprox(sum(d .* expand(.!m, ndims(d))),0)
 end
 end
 
@@ -62,11 +62,11 @@ using ForwardBackward
 m = rand(Bool, siz...)
 XM = MaskedState(Xb, m, m)
 Xt = Flowfusion.bridge(p, Xa, XM, 0.1)
-@assert typeof(Xt) == typeof(XM)
+@test typeof(Xt) == typeof(XM)
 if !(p isa InterpolatingDiscreteFlow)
-@assert isapprox(sum((tensor(Xt) .== tensor(Xb))), sum(m) * (length(tensor(Xb)) / length(m)))
+@test isapprox(sum((tensor(Xt) .== tensor(Xb))), sum(.!m) * (length(tensor(Xb)) / length(m)))
 else
-@assert sum((tensor(Xt) .== tensor(Xb))) >= sum(m) * (length(tensor(Xb)) / length(m))
+@test sum((tensor(Xt) .== tensor(Xb))) >= sum(.!m) * (length(tensor(Xb)) / length(.!m))
 end
 
 #step - doesn't propogate the mask
@@ -76,10 +76,10 @@ using ForwardBackward
 XM = MaskedState(Xa, m, m)
 if !(p isa InterpolatingDiscreteFlow)
 Xt = Flowfusion.step(p, XM, Xa, 0.1, 0.1)
-@assert isapprox(sum(tensor(Xt) .!= tensor(XM)), 0) #Because step size is zero
+@test isapprox(sum(tensor(Xt) .!= tensor(XM)), 0) #Because step size is zero
 else
 Xt = Flowfusion.step(p, XM, onehot(Xa), 0.1, 0.1)
-@assert isapprox(sum(tensor(Xt) .!= tensor(XM)), 0) #Because step size is zero
+@test isapprox(sum(tensor(Xt) .!= tensor(XM)), 0) #Because step size is zero
 end
 end
 
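
The m → .!m edits above pin down the new mask meaning the commit message refers to: every check that used to look at positions where the mask is true now looks at positions where it is false, so a masked result must now agree with the conditioning state exactly where the mask is FALSE; by implication, true marks the positions left free to change. (The @assert → @test swaps are an independent cleanup: failures now report through the test set instead of aborting the run.) A minimal self-contained sketch of the convention on plain arrays, where apply_mask is a hypothetical stand-in for Flowfusion.mask, not the library's implementation:

using Test

#Hypothetical stand-in: free values come from xa where the mask is true,
#and stay pinned to the conditioning values xb where it is false.
apply_mask(xa, xb, m) = ifelse.(m, xa, xb)

xa, xb = randn(5, 4), randn(5, 4)
m = rand(Bool, 5, 4)
xc = apply_mask(xa, xb, m)

d = xb .- xc
@test isapprox(sum(d .* .!m), 0) #mirrors the updated tests: zero wherever m is false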
