Skip to content

Commit e4ee29d

Browse files
committed
More examples
1 parent aaec4f5 commit e4ee29d

File tree

4 files changed

+144
-10
lines changed

4 files changed

+144
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1313

1414
[compat]
1515
Adapt = "4.1.1"
16-
ForwardBackward = "1.0.0"
16+
ForwardBackward = "0.1.0"
1717
Manifolds = "0.10.12"
1818
NNlib = "0.9.27"
1919
OneHotArrays = "0.2.6"

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
# Flowfusion
1+
# Flowfusion.jl
2+
3+
![Image](https://github.com/user-attachments/assets/f2754ba5-b798-4db9-8ce6-a0324b89a534)
24

35
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Flowfusion.jl/stable/)
46
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Flowfusion.jl/dev/)
@@ -7,7 +9,7 @@
79

810

911

10-
Flowfusion is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces. It provides a unified framework for:
12+
Flowfusion.jl is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces, all in a single unified framework and interface.
1113

1214
## Features
1315

examples/discrete.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,19 @@ P = NoisyInterpolatingDiscreteFlow(0.1)
4747
model = DModel(embeddim = 128, l = 2, K = 33, layers = 2)
4848

4949
eta = 0.005
50-
opt_state = Flux.setup(Adam(eta), model)
50+
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.0001), model)
5151

52-
iters = 4000
52+
iters = 400
5353
for i in 1:iters
5454
#Set up a batch of training pairs, and t
55-
X1 = DiscreteState(33, sampleX1(n_samples))
56-
X0 = DiscreteState(33, sampleX0(n_samples))
55+
X1 = onehot(DiscreteState(33, sampleX1(n_samples)))
56+
X0 = onehot(DiscreteState(33, sampleX0(n_samples)))
5757
t = rand(T, 1, n_samples)
5858
#Construct the bridge:
59-
Xt = stochastic(Float32, bridge(P, X0, X1, t))
59+
Xt = dense(bridge(P, X0, X1, t)) #Zygote doesn't like the onehot input, so we make it dense.
6060
#Gradient
6161
l,g = Flux.withgradient(model) do m
62-
floss(P, m(t,Xt), onehot(X1), t) #CE loss - Scaling with t doesn't seem critical for this one
62+
floss(P, m(t,Xt), X1, scalefloss(P,t,1)) #I prefer pow = 1 for discrete.
6363
end
6464
#Update
6565
Flux.update!(opt_state, model, g[1])
@@ -72,10 +72,11 @@ for i in 1:iters
7272
end
7373
end
7474

75+
7576
n_inference_samples = 10000
7677
X0 = DiscreteState(33, sampleX0(n_inference_samples))
7778
paths = Tracker()
78-
samp = gen(P, X0, (t,Xt) -> softmax(model(t,onehot(Xt))), 0f0:0.001f0:1f0, tracker = paths) #Note the softmax here
79+
@time samp = gen(P, X0, (t,Xt) -> softmax(model(t,onehot(Xt))), 0f0:0.01f0:1f0, tracker = paths) #<- Note the softmax, and onehot here
7980

8081
pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, color = "blue", alpha = 0.4, label = "Initial", size = (400,400), legend = :topleft, xlim = (1,33), ylim = (1,33))
8182
scatter!(samp.state[1,:],samp.state[2,:], msw = 0, color = "green", alpha = 0.04, label = :none)
@@ -89,3 +90,10 @@ plot!([-10],[-10], color = "red", label = "Trajectory", alpha = 0.4)
8990
pl
9091
savefig("discrete_$P.svg")
9192

93+
#=
94+
#Another way to do this is to make the X0 onehot, and then it'll stay onehot through the gen:
95+
X0 = onehot(DiscreteState(33, sampleX0(n_inference_samples)))
96+
paths = Tracker()
97+
@time samp = gen(P, X0, (t,Xt) -> softmax(model(t,Xt)), 0f0:0.01f0:1f0, tracker = paths)
98+
=#
99+

examples/logo_example.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
using Pkg
2+
Pkg.activate(".")
3+
using Revise
4+
5+
#using FileIO, Images
6+
#img = FileIO.load("fflogo.png")
7+
#arr = [a.r for a in img] .< 0.5
8+
#sampinds = (x -> (x[2], 265-x[1])).(CartesianIndices(arr)[arr])
9+
#CSV.write("logoinds.csv", DataFrame(sampinds))
10+
11+
using CSV, DataFrames
12+
df = CSV.read("logoinds.csv", DataFrame)
13+
sampinds = [Tuple(df[i,:]) for i in 1:size(df,1)]
14+
flowinds = [s ./ 200 for s in sampinds if s[2] > 0]
15+
fusioninds = [s ./ 200 for s in sampinds if s[2] <= 0]
16+
17+
Pkg.develop(path="../../ForwardBackward/")
18+
Pkg.develop(path="../")
19+
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds
20+
21+
#Set up a Flux model: X̂1 = model(t,Xt)
22+
struct FModel{A}
23+
layers::A
24+
end
25+
Flux.@layer FModel
26+
function FModel(; embeddim = 128, layers = 3)
27+
embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim))
28+
embed_state = Chain(RandomFourierFeatures(2 => embeddim, 3f0), Dense(embeddim => embeddim))
29+
embed_angle = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim))
30+
ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
31+
decode = Dense(embeddim => 2)
32+
decode_angle = Dense(embeddim => 1)
33+
layers = (;embed_time, embed_state, embed_angle, ffs, decode, decode_angle)
34+
FModel(layers)
35+
end
36+
37+
function (f::FModel)(t, Xt)
38+
tXt, aXt = tensor.(Xt)
39+
l = f.layers
40+
aenc = vcat(sin.(aXt), cos.(aXt))
41+
tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
42+
x = l.embed_time(tv) .+ l.embed_state(tXt) .+ l.embed_angle(aenc)
43+
for ff in l.ffs
44+
x = x .+ ff(x)
45+
end
46+
scal = (1.05f0 .- expand(t, ndims(tXt)))
47+
(tXt .+ l.decode(x) .* scal), (l.decode_angle(x) .* scal)
48+
end
49+
50+
T = Float32
51+
n_samples = 1000
52+
M = Torus(1)
53+
ManifoldState(M, Array{Float32}.(rand(M, n_samples)))
54+
sampleX0(n_samples) = ContinuousState(T.(stack(rand(flowinds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([0.6f0], n_samples))
55+
sampleX1(n_samples) = ContinuousState(T.(stack(rand(fusioninds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([-2.54159f0], n_samples))
56+
57+
model = FModel(embeddim = 384, layers = 5)
58+
n_samples = 500
59+
60+
#The process:
61+
P = (BrownianMotion(0.05f0), ManifoldProcess(0.1f0))
62+
63+
#Optimizer:
64+
eta = 0.001
65+
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)
66+
67+
iters = 6000
68+
for i in 1:iters
69+
#Set up a batch of training pairs, and t, where X1 is a MaskedState:
70+
X0 = sampleX0(n_samples)
71+
X1 = sampleX1(n_samples)
72+
t = rand(T, n_samples)
73+
#Construct the bridge:
74+
Xt = bridge(P, X0, X1, t)
75+
ξ = Guide(Xt[2], X1[2])
76+
#Gradient:
77+
l,g = Flux.withgradient(model) do m
78+
hat = m(t,Xt)
79+
floss(P[1], hat[1], X1[1], scalefloss(P[1], t)) + floss(P[2], hat[2], ξ, scalefloss(P[2], t))
80+
end
81+
#Update:
82+
Flux.update!(opt_state, model, g[1])
83+
#Logging, and lr cooldown:
84+
if i % 10 == 0
85+
if i > iters - 2000
86+
eta *= 0.975
87+
Optimisers.adjust!(opt_state, eta)
88+
end
89+
println("i: $i; Loss: $l; eta: $eta")
90+
end
91+
end
92+
93+
function smodel(t, Xt)
94+
hat = model(t,Xt)
95+
return hat[1], Guide(hat[2])
96+
end
97+
98+
#Generate unconditional samples:
99+
n_inference_samples = 5000
100+
X0 = sampleX0(n_inference_samples)
101+
paths = Tracker()
102+
samp = gen(P, X0, smodel, 0f0:0.005f0:1f0, tracker = paths)
103+
104+
cstate = tensor(samp[1])
105+
astate = tensor(samp[2])
106+
zcstate = tensor(X0[1])
107+
zastate = tensor(X0[2])
108+
109+
#scatter(zcstate[1,:], zcstate[2,:], msw = 0, ms = 1.5, markerz = zastate[1,:], cmap = :hsv)
110+
#scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1.5, markerz = astate[1,:], cmap = :hsv)
111+
scatter(zcstate[1,:], zcstate[2,:], msw = 0, ms = 1.5, markerz = zastate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
112+
scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1.5, markerz = astate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
113+
scatter!([-100,-100],[-100,-100], markerz = [-pi,pi], label = :none, colorbar = :none, axis=([], false))
114+
115+
postraj = stack([tensor(p[1]) for p in paths.xt])
116+
angtraj = stack([tensor(p[2]) for p in paths.xt])
117+
118+
anim = @animate for i vcat([1 for i in 1:20], 1:size(postraj, 3), [size(postraj, 3) for i in 1:20])
119+
scatter(postraj[1,:,i], postraj[2,:,i], msw = 0, ms = 1, markerz = angtraj[1,:,i], cmap = :hsv, label = :none, xlim = (-0.0, 5.2), ylim = (-1.3, 1.3), size = (400, 200))
120+
scatter!([-100,-100],[-100,-100], markerz = [-pi,pi], label = :none, colorbar = :none, axis=([], false))
121+
end
122+
gif(anim, "logo_$(P).mp4", fps = 30)
123+
gif(anim, "logo_$(P).gif", fps = 30)
124+

0 commit comments

Comments
 (0)