Commit 599642a ("Working"), parent 8c45376
File tree: 10 files changed (+330, -70 lines)

**README.md** (108 additions, 37 deletions)
@@ -7,58 +7,129 @@
**Removed** (the old README content):

Flowfusion is a Julia package for learning and sampling from conditional diffusion processes across continuous, discrete, and manifold spaces. It provides a unified framework for:

- Learning conditional flows between states
- Sampling from learned distributions
- Working with various state types (continuous, discrete, manifold)
- Handling partial observations and masked states

### Multiple State Types
- Continuous states (Euclidean spaces)
- Discrete states (categorical variables)
- Manifold states (probability simplexes, tori, rotations)
- Masked states for partial conditioning

### Supported Processes
- Deterministic flows
- Brownian motion
- Ornstein-Uhlenbeck
- Discrete flows (InterpolatingDiscreteFlow, NoisyInterpolatingDiscreteFlow)
- Manifold-specific processes

### Core Operations
- `bridge(P, X0, X1, t)`: Sample intermediate states conditioned on start and end states
- `gen(P, X0, model, steps)`: Generate sequences using a learned model
- Support for both direct state prediction and tangent coordinate prediction

### Training
- Loss functions adapted to different state/process types
- Support for masked training (partial observations)
- Time scaling for improved training dynamics
- Integration with Flux.jl for neural network models

The old examples list:

- `continuous.jl`: Learning flows between clusters in continuous space
- `discrete.jl`: Learning categorical transitions
- `torus.jl`: Learning flows on a torus manifold
- `probabilitysimplex.jl`: Learning flows between probability distributions

The old Installation block (`using Pkg; Pkg.add("Flowfusion")`) and the `## Quick Start` stub (`using Flowfusion, Flux` followed by `#To do.`) are also replaced below.

**Added** (the new README content):

Flowfusion is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces. It provides a unified framework for:

## Features

- Controllable noise (or fully deterministic for flow matching)
- Flexible initial $X_0$ distribution
- Conditioning via masking (see the sketch after this list)
- States: Continuous, discrete, and a wide variety of manifolds supported (via [Manifolds.jl](https://github.com/JuliaManifolds/Manifolds.jl))
- Compound states supported (e.g. jointly sampling from both continuous and discrete variables)
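Two of these features, masking and compound states, are worth a concrete illustration. Below is a minimal sketch using the `MaskedState`/`ContinuousState`/`DiscreteState`/`onehot` API exactly as it appears in `examples/masked_tuple.jl` later in this commit; the array sizes and process parameters are illustrative only:

```julia
using ForwardBackward, Flowfusion

# Continuous part: 2-D positions for 100 samples.
x = rand(Float32, 2, 100)
cmask = rand(2, 100) .< 0.5  # 1 = noised/generated, 0 = conditioned (held fixed)

# MaskedState takes the state plus a conditioning mask and a loss mask
# (kept identical here, as in examples/masked_tuple.jl):
Xc = MaskedState(ContinuousState(x), cmask, cmask)

# Discrete part: 100 samples over 4 categories, one-hot encoded:
Xd = onehot(DiscreteState(4, rand(1:4, 100)))

# A compound state is a tuple, paired elementwise with a tuple of processes:
X = (Xc, Xd)
P = (BrownianMotion(0.15f0), NoisyInterpolatingDiscreteFlow(0.1))
```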
### Basic idea:

- Generate `X0` and `X1` states from your favorite distribution, and a random `t` between 0 and 1
- `Xt = bridge(P, X0, X1, t)`: Sample intermediate states conditioned on start and end states
- Train the model to predict how to get to `X1` from `Xt`
- `gen(P, X0, model, steps)`: Generate sequences using a learned model (a condensed sketch of these steps follows this list)
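Condensing those four steps into a sketch (the data distributions and the identity-style `model` here are toy placeholders, not part of the package; the full runnable example comes next):

```julia
using ForwardBackward, Flowfusion

# Toy placeholders (not part of the package) for the data and the model:
n = 64
sampleX0(n) = randn(Float32, 2, n)
sampleX1(n) = randn(Float32, 2, n) .+ 5f0
model(t, Xt) = tensor(Xt)              # placeholder: "predicts" X̂1 = Xt

P  = BrownianMotion(0.15f0)            # any supported process
X0 = ContinuousState(sampleX0(n))
X1 = ContinuousState(sampleX1(n))
t  = rand(Float32, n)                  # random times between 0 and 1

# Sample the intermediate state conditioned on both endpoints:
Xt = bridge(P, X0, X1, t)

# One loss evaluation: compare the model's prediction against X1:
loss = floss(P, model(t, Xt), X1, scalefloss(P, t))

# Generation: integrate from X0 using the (trained) model:
samples = gen(P, X0, model, 0f0:0.005f0:1f0)
```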

## Examples

The package includes several examples demonstrating different use cases:

- `continuous.jl`: Learning a continuous distribution
- `torus.jl`: Continuous distributions on a manifold
- `discrete.jl`: Discrete distributions with discrete processes
- `probabilitysimplex.jl`: Discrete distributions with continuous processes via a probability simplex manifold
- `continuous_masked.jl`: Conditioning on partial observations
- `masked_tuple.jl`: Jointly sampling from continuous and discrete variables, with conditioning

## Installation

```julia
]add https://github.com/MurrellGroup/Flowfusion.jl
```

## A full example

```julia
using ForwardBackward, Flowfusion, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: X̂1 = model(t,Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_time, embed_state, ffs, decode)
    FModel(layers)
end
function (f::FModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
end

model = FModel(embeddim = 256, layers = 3, spacedim = 2)

#Distributions for training:
T = Float32
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
n_samples = 400

#The process:
P = BrownianMotion(0.15f0)
#P = Deterministic()

#Optimizer:
eta = 0.001
opt_state = Flux.setup(AdamW(eta = eta), model)

iters = 4000
for i in 1:iters
    #Set up a batch of training pairs, and t:
    X0 = ContinuousState(sampleX0(n_samples))
    X1 = ContinuousState(sampleX1(n_samples))
    t = rand(T, n_samples)
    #Construct the bridge:
    Xt = bridge(P, X0, X1, t)
    #Gradient & update:
    l,g = Flux.withgradient(model) do m
        floss(P, m(t,Xt), X1, scalefloss(P, t))
    end
    Flux.update!(opt_state, model, g[1])
    (i % 10 == 0) && println("i: $i; Loss: $l")
end

#Generate samples by stepping from X0
n_inference_samples = 5000
X0 = ContinuousState(sampleX0(n_inference_samples))
samples = gen(P, X0, model, 0f0:0.005f0:1f0)

#Plotting
pl = scatter(X0.state[1,:], X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
X1true = sampleX1(n_inference_samples)
scatter!(X1true[1,:], X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
scatter!(samples.state[1,:], samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
```
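Note the commented-out `#P = Deterministic()`: per the features list, the noise is controllable, and swapping `BrownianMotion(0.15f0)` for `Deterministic()` turns this same training loop into plain flow matching, with nothing else needing to change.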
## Tracking trajectories

```julia
#Generate samples by stepping from X0
n_inference_samples = 5000
X0 = ContinuousState(sampleX0(n_inference_samples))
paths = Tracker() #<- A tracker to record the trajectory
samples = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)

#Plotting:
pl = scatter(X0.state[1,:], X0.state[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
for i in 1:50:1000
    plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = i==1 ? "Trajectory" : :none, alpha = 0.4)
end
X1true = sampleX1(n_inference_samples)
scatter!(X1true[1,:], X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
scatter!(samples.state[1,:], samples.state[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
```
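Judging from the plotting loop above, `stack_tracker(paths, :xt)` stacks the recorded states so that `xttraj[d, i, :]` is coordinate `d` of sample `i` across all integration steps; each `plot!` call therefore draws one sample's path from `X0` toward `X1`.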

**examples/masked_tuple.jl** (new file: 121 additions)

@@ -0,0 +1,121 @@
```julia
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: X̂1 = model(t,Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_mask = Dense(2 => embeddim) #<- The model should usually know which positions are masked
    embed_mask_discrete = Dense(1 => embeddim)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_discrete_state = Dense(4 => embeddim)
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    decode_discrete = Dense(embeddim => 4)
    layers = (; embed_mask, embed_mask_discrete, embed_time, embed_state, embed_discrete_state, ffs, decode, decode_discrete)
    FModel(layers)
end

function (f::FModel)(t, Xt)
    cXt, dXt = Xt
    l = f.layers
    tXt = tensor(cXt)
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt) .+ l.embed_mask(cXt.cmask) .+ l.embed_mask_discrete(reshape(dXt.cmask, 1, :)) .+ l.embed_discrete_state(tensor(dXt)) #<- Mask handling
    for ff in l.ffs
        x = x .+ ff(x)
    end
    scal = (1.05f0 .- expand(t, ndims(tXt)))
    (tXt .+ l.decode(x) .* scal), (l.decode_discrete(x) .* scal)
end

model = FModel(embeddim = 384, layers = 3, spacedim = 2)

#Distributions for training:
T = Float32
function sampleX1(n_samples)
    cstate = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
    dstate = ones(Int64, n_samples)
    dstate[:] .+= cstate[1,:] .> 0
    dstate[:] .+= cstate[2,:] .> 0
    return cstate, dstate
end
sampleX0(n_samples) = (rand(T, 2, n_samples) .+ 2), rand(1:4, n_samples)
n_samples = 500

#The masking distribution for the continuous part of the state (the discrete mask is drawn inline in the loop below)
X1mask(n_samples) = rand(2, n_samples) .< 0.75

#The process:
P = (BrownianMotion(0.4f0), NoisyInterpolatingDiscreteFlow(0.1))

#Optimizer:
eta = 0.01
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.0001), model)

iters = 6000
for i in 1:iters
    #Set up a batch of training pairs, and t, where X1 is a MaskedState:
    X1cm = X1mask(n_samples)
    X1dm = rand(n_samples) .< 0.33
    x1 = sampleX1(n_samples)
    X1 = (MaskedState(ContinuousState(x1[1]), X1cm, X1cm), MaskedState(onehot(DiscreteState(4, x1[2])), X1dm, X1dm))
    x0 = sampleX0(n_samples)
    X0 = (ContinuousState(x0[1]), onehot(DiscreteState(4, x0[2])))
    t = rand(T, n_samples)
    #Construct the bridge:
    Xt = bridge(P, X0, X1, t) #<- Only positions where mask=1 are noised because X1 is a MaskedState
    #Gradient:
    l,g = Flux.withgradient(model) do m
        floss(P, m(t,Xt), X1, scalefloss(P, t))
    end
    #Update:
    Flux.update!(opt_state, model, g[1])
    #Logging, and lr cooldown:
    if i % 10 == 0
        if i > iters - 2000
            eta *= 0.975
            Optimisers.adjust!(opt_state, eta)
        end
        println("i: $i; Loss: $l; eta: $eta")
    end
end

#Generate unconditional samples:
n_inference_samples = 5000
pl = plot(;size = (400,400), legend = :topleft)
Xcm = rand(2, n_inference_samples) .> 0
Xdm = rand(n_inference_samples) .> 0
x0 = sampleX0(n_inference_samples)
X0 = (MaskedState(ContinuousState(x0[1]), Xcm, Xcm), MaskedState(onehot(DiscreteState(4, x0[2])), Xdm, Xdm))
paths = Tracker()
samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
cstate = tensor(samp[1])
dstate = tensor(unhot(samp[2].S))
scatter!(cstate[1,:], cstate[2,:], markerz = dstate, cmap = :brg, msw = 0, ms = 1, alpha = 0.3, label = :none, colorbar = :none)
savefig("tuple_cat_unconditioned_$P.svg")

#Generate conditioned samples:
n_inference_samples = 2000
pl = plot(;size = (400,400), legend = :topleft)
for dval in [1, 2, 3]
    Xcm = rand(2, n_inference_samples) .> 0
    Xdm = rand(n_inference_samples) .< 0
    x0 = sampleX0(n_inference_samples)
    x0[2] .= dval
    X0 = (MaskedState(ContinuousState(x0[1]), Xcm, Xcm), MaskedState(onehot(DiscreteState(4, x0[2])), Xdm, Xdm))
    paths = Tracker()
    samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
    cstate = tensor(samp[1])
    dstate = tensor(unhot(samp[2].S))
    scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1, alpha = 0.3, label = "D = $dval")
end
pl
savefig("tuple_cat_conditioned_$P.svg")
```
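A note on the inference-time masks above: `rand(n) .> 0` is effectively all-true, so every position evolves freely in the unconditional run, while `rand(n) .< 0` is all-false, clamping the discrete coordinate to `dval` throughout generation; that clamping is what separates the conditioned plot into one cluster per category.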

**examples/probabilitysimplex.jl** (2 additions, 2 deletions)

@@ -55,10 +55,10 @@ for i in 1:iters

```diff
     #Construct the bridge:
     Xt = bridge(P, X0, X1, t)
     #Get the Xt->X1 tangent coordinates:
-    ξ = tangent_guide(Xt, X1)
+    ξ = Guide(Xt, X1)
     #Gradient:
     l,g = Flux.withgradient(model) do m
-        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
+        floss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
     end
     #Update:
     Flux.update!(opt_state, model, g[1])
```

**examples/torus.jl** (3 additions, 2 deletions)

@@ -54,10 +54,11 @@ for i in 1:iters

```diff
     #Construct the bridge:
     Xt = bridge(P, X0, X1, t)
     #Compute the tangent coordinates:
-    ξ = tangent_guide(Xt, X1)
+    ξ = Guide(Xt, X1)
     #Gradient
     l,g = Flux.withgradient(model) do m
-        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
+        floss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
     end
     #Update
     Flux.update!(opt_state, model, g[1])
```

**src/Flowfusion.jl** (2 additions, 0 deletions)

@@ -2,6 +2,7 @@

```diff
 Need to test/do:
 Urgent:
 - Test tuples!
+- Test Manifolds with masking (especially tangent_guide and apply_tangent etc)
 - Masking (cmask) on all state types for bridge and gen
 - Masking (lmask) on all state types for both losses
 - tensor on masked states
```

@@ -42,6 +43,7 @@ export

```diff
     Tracker,
     stack_tracker,
     onehot,
+    unhot,
     FProcess,
     floss,
     tcloss,
```
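Since `unhot` is newly exported here and used in `examples/masked_tuple.jl` above (as `unhot(samp[2].S)`), a minimal round-trip sketch, assuming only that it inverts `onehot`:

```julia
using ForwardBackward, Flowfusion

d  = DiscreteState(4, rand(1:4, 10))  # 10 samples over 4 categories
oh = onehot(d)                        # one-hot encoding, as used for bridging/generation
d2 = unhot(oh)                        # back to integer categories
tensor(d2) == tensor(d)               # should round-trip exactly
```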
