Skip to content

Commit 6ff3698

Browse files
committed
Updating readme and example
1 parent 48a6005 commit 6ff3698

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
Flowfusion.jl is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces, all in a single unified framework and interface.
99

10-
![Image](https://github.com/user-attachments/assets/d739c07e-f9e9-4aef-932e-c36cae182391)
11-
![Image](https://github.com/user-attachments/assets/f2754ba5-b798-4db9-8ce6-a0324b89a534)
10+
![Image](https://github.com/user-attachments/assets/ff7f25e6-441d-4840-ac9c-a849e7b57aa7)
1211

1312
The animated logo shows samples from a model trained to jointly transport a 2D point and an angular hue between two distributions. For the 2D point, the left side uses "Flow matching" with deterministic trajectories, and the right uses a Brownian bridge. For both sides, the angular hue is diffused via an angular Brownian bridge. The hue endpoints are antipodal, and you can see both paths, in opposite angular directions, are sampled.
1413

examples/logo_example.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
using Pkg
22
Pkg.activate(".")
33
using Revise
4+
Pkg.develop(path="../../ForwardBackward/")
5+
Pkg.develop(path="../")
6+
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds
7+
using CSV, DataFrames
48

9+
#Creating the logoinds.csv file:
510
#using FileIO, Images
611
#img = FileIO.load("fflogo.png")
712
#arr = [a.r for a in img] .< 0.5
813
#sampinds = (x -> (x[2], 265-x[1])).(CartesianIndices(arr)[arr])
914
#CSV.write("logoinds.csv", DataFrame(sampinds))
1015

11-
using CSV, DataFrames
12-
df = CSV.read("logoinds.csv", DataFrame)
16+
df = CSV.read("logoinds.csv", DataFrame) #https://github.com/user-attachments/files/18864465/logoinds.csv
1317
sampinds = [Tuple(df[i,:]) for i in 1:size(df,1)]
1418
flowinds = [s ./ 200 for s in sampinds if s[2] > 0]
1519
fusioninds = [s ./ 200 for s in sampinds if s[2] <= 0]
1620

17-
Pkg.develop(path="../../ForwardBackward/")
18-
Pkg.develop(path="../")
19-
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds
20-
2121
#Set up a Flux model: X̂1 = model(t,Xt)
2222
struct FModel{A}
2323
layers::A
@@ -54,18 +54,18 @@ ManifoldState(M, Array{Float32}.(rand(M, n_samples)))
5454
sampleX0(n_samples) = ContinuousState(T.(stack(rand(flowinds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([0.6f0], n_samples))
5555
sampleX1(n_samples) = ContinuousState(T.(stack(rand(fusioninds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([-2.54159f0], n_samples))
5656

57-
model = FModel(embeddim = 512, layers = 5)
57+
model = FModel(embeddim = 384, layers = 4)
5858
n_samples = 500
5959

6060
#The process:
61-
P = (BrownianMotion(0.05f0), ManifoldProcess(0.1f0))
62-
#P = (Deterministic(), ManifoldProcess(0.1f0))
61+
P = (FProcess(BrownianMotion(0.1f0), t -> 1-(1-t)^2), ManifoldProcess(0.1f0))
62+
#P = (FProcess(Deterministic(), t -> 1-(1-t)^2), ManifoldProcess(0.1f0))
6363

6464
#Optimizer:
6565
eta = 0.001
6666
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)
6767

68-
iters = 10000
68+
iters = 8000
6969
for i in 1:iters
7070
#Set up a batch of training pairs, and t, where X1 is a MaskedState:
7171
X0 = sampleX0(n_samples)
@@ -102,17 +102,8 @@ X0 = sampleX0(n_inference_samples)
102102
paths = Tracker()
103103
samp = gen(P, X0, smodel, 0f0:0.005f0:1f0, tracker = paths)
104104

105-
cstate = tensor(samp[1])
106-
astate = tensor(samp[2])
107-
zcstate = tensor(X0[1])
108-
zastate = tensor(X0[2])
109-
110-
scatter(zcstate[1,:], zcstate[2,:], msw = 0, ms = 1.5, markerz = zastate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
111-
scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1.5, markerz = astate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
112-
scatter!([-100,-100],[-100,-100], markerz = [-pi,pi], label = :none, colorbar = :none, axis=([], false))
113-
114-
postraj = stack([tensor(p[1]) for p in paths.xt])
115-
angtraj = stack([tensor(p[2]) for p in paths.xt])
105+
postraj = stack(vcat([tensor(X0[1])], [tensor(p[1]) for p in paths.xt], [tensor(samp[1])]))
106+
angtraj = stack(vcat([tensor(X0[2])], [tensor(p[2]) for p in paths.xt], [tensor(samp[2])]))
116107

117108
anim = @animate for i vcat([1 for i in 1:20], 1:size(postraj, 3), [size(postraj, 3) for i in 1:20])
118109
scatter(postraj[1,:,i], postraj[2,:,i], msw = 0, ms = 1, markerz = angtraj[1,:,i], cmap = :hsv, label = :none, xlim = (-0.0, 5.2), ylim = (-1.3, 1.3), size = (400, 200))
@@ -121,3 +112,21 @@ end
121112
gif(anim, "logo_$(P).mp4", fps = 30)
122113
gif(anim, "logo_$(P).gif", fps = 30)
123114

115+
#=
116+
#To create a side-by-side animation (useful for comparing schedules, noise, etc), run the above once, then store the trajectories, then run it again, and then merge, with an offset:
117+
#diffpostraj = copy(postraj)
118+
#diffangtraj = copy(angtraj)
119+
120+
flowpostraj = copy(postraj)
121+
flowangtraj = copy(angtraj)
122+
123+
mergedpostraj = hcat(diffpostraj .+ reshape([6, 0], 2, 1, 1), flowpostraj)
124+
mergedangtraj = hcat(diffangtraj, flowangtraj)
125+
126+
anim = @animate for i ∈ vcat([1 for i in 1:20], 1:size(mergedpostraj, 3), [size(mergedpostraj, 3) for i in 1:20])
127+
scatter(mergedpostraj[1,:,i], mergedpostraj[2,:,i], msw = 0, ms = 1, markerz = mergedangtraj[1,:,i], cmap = :hsv, label = :none, xlim = (-0.0, 11.2), ylim = (-1.3, 1.3), size = (800, 200))
128+
scatter!([-100,-100],[-100,-100], markerz = [-pi,pi], label = :none, colorbar = :none, axis=([], false))
129+
end
130+
gif(anim, "mergedlogo_$(P).mp4", fps = 30)
131+
gif(anim, "mergedlogo_$(P).gif", fps = 30)
132+
=#

0 commit comments

Comments
 (0)