Commit 8c45376

Masking example
1 parent e0e9e9e commit 8c45376

File tree

3 files changed: +162 −4 lines changed

README.md

Lines changed: 58 additions & 0 deletions
@@ -4,3 +4,61 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Flowfusion.jl/dev/)
[![Build Status](https://github.com/MurrellGroup/Flowfusion.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Flowfusion.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/MurrellGroup/Flowfusion.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Flowfusion.jl)

Flowfusion is a Julia package for learning and sampling from conditional diffusion processes across continuous, discrete, and manifold spaces. It provides a unified framework for:

- Learning conditional flows between states
- Sampling from learned distributions
- Working with various state types (continuous, discrete, manifold)
- Handling partial observations and masked states

## Features

### Multiple State Types
- Continuous states (Euclidean spaces)
- Discrete states (categorical variables)
- Manifold states (probability simplexes, tori, rotations)
- Masked states for partial conditioning

### Supported Processes
- Deterministic flows
- Brownian motion
- Ornstein-Uhlenbeck
- Discrete flows (InterpolatingDiscreteFlow, NoisyInterpolatingDiscreteFlow)
- Manifold-specific processes

### Core Operations
- `bridge(P, X0, X1, t)`: Sample intermediate states conditioned on start and end states (see the sketch below)
- `gen(P, X0, model, steps)`: Generate sequences using a learned model
- Support for both direct state prediction and tangent coordinate prediction
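
A minimal sketch of these two operations, mirroring `examples/continuous_masked.jl` (`model` here stands for any callable `X̂1 = model(t, Xt)` and is not defined in this snippet):

```julia
using ForwardBackward, Flowfusion

P = BrownianMotion(0.2f0)                     # the process
X0 = ContinuousState(randn(Float32, 2, 100))  # batch of start states
X1 = ContinuousState(randn(Float32, 2, 100))  # batch of end states
t = rand(Float32, 100)                        # times in [0, 1]

Xt = bridge(P, X0, X1, t)  # sample intermediate states conditioned on X0 and X1
# With a trained model:
# samp = gen(P, X0, model, 0f0:0.005f0:1f0)
```
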
### Training
- Loss functions adapted to different state/process types (see the training-step sketch below)
- Support for masked training (partial observations)
- Time scaling for improved training dynamics
- Integration with Flux.jl for neural network models
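
A sketch of how these combine into one training step, condensed from `examples/continuous_masked.jl` (it assumes `P`, `model`, `opt_state`, and the helpers `sampleX0`/`sampleX1` are set up as in that example; `n` is the batch size):

```julia
using Flux

X1m = rand(2, n) .< 0.75                                  # 1 = noised & in the loss; 0 = conditioned on
X1 = MaskedState(ContinuousState(sampleX1(n)), X1m, X1m)  # target with conditioning and loss masks
X0 = ContinuousState(sampleX0(n))
t = rand(Float32, n)
Xt = bridge(P, X0, X1, t)                    # only positions with mask = 1 are noised
l, g = Flux.withgradient(model) do m
    floss(P, m(t, Xt), X1, scalefloss(P, t)) # time-scaled loss over mask = 1 positions only
end
Flux.update!(opt_state, model, g[1])
```
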
## Examples

The package includes several examples demonstrating different use cases:

- `continuous.jl`: Learning flows between clusters in continuous space
- `discrete.jl`: Learning categorical transitions
- `torus.jl`: Learning flows on a torus manifold
- `probabilitysimplex.jl`: Learning flows between probability distributions

## Installation

```julia
using Pkg
Pkg.add("Flowfusion")
```

## Quick Start

```julia
using Flowfusion, Flux
#To do.
```
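
Until the quick start is filled in, the shape of inference looks like this sketch (assuming `P`, `model`, and `sampleX0` from the masking example below; the all-ones mask conditions on nothing but lets the model see the mask channel it was trained with):

```julia
n = 1000
m = trues(2, n)                                      # all ones: nothing conditioned
X0 = MaskedState(ContinuousState(sampleX0(n)), m, m)
paths = Tracker()                                    # optionally records the sampling path
samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)
```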

examples/continuous_masked.jl

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
```julia
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots

#Set up a Flux model: X̂1 = model(t,Xt)
struct FModel{A}
    layers::A
end
Flux.@layer FModel
function FModel(; embeddim = 128, spacedim = 2, layers = 3)
    embed_mask = Dense(2 => embeddim) #<- The model should usually know which positions are masked
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    embed_state = Chain(RandomFourierFeatures(2 => embeddim, 1f0), Dense(embeddim => embeddim, swish))
    ffs = [Dense(embeddim => embeddim, swish) for _ in 1:layers]
    decode = Dense(embeddim => spacedim)
    layers = (; embed_mask, embed_time, embed_state, ffs, decode)
    FModel(layers)
end

function (f::FModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    tv = zero(tXt[1:1,:]) .+ expand(t, ndims(tXt))
    x = l.embed_time(tv) .+ l.embed_state(tXt) .+ l.embed_mask(Xt.cmask) #<- Mask handling
    for ff in l.ffs
        x = x .+ ff(x)
    end
    tXt .+ l.decode(x) .* (1.05f0 .- expand(t, ndims(tXt)))
end

model = FModel(embeddim = 256, layers = 3, spacedim = 2)

#Distributions for training:
T = Float32
sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))
sampleX0(n_samples) = rand(T, 2, n_samples) .+ 2
n_samples = 500

#The masking distribution:
X1mask(n_samples) = rand(2, n_samples) .< 0.75

#The process:
P = BrownianMotion(0.2f0)

#Optimizer:
eta = 0.01
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)

iters = 6000
for i in 1:iters
    #Set up a batch of training pairs, and t, where X1 is a MaskedState:
    X1m = X1mask(n_samples)
    X1 = MaskedState(ContinuousState(sampleX1(n_samples)), X1m, X1m)
    X0 = ContinuousState(sampleX0(n_samples))
    t = rand(T, n_samples)
    #Construct the bridge:
    Xt = bridge(P, X0, X1, t) #<- Only positions where mask=1 are noised, because X1 is a MaskedState
    #Gradient:
    l,g = Flux.withgradient(model) do m
        floss(P, m(t,Xt), X1, scalefloss(P, t)) #<- Only positions where mask=1 are included in the loss
    end
    #Update:
    Flux.update!(opt_state, model, g[1])
    #Logging, and lr cooldown:
    if i % 10 == 0
        if i > iters - 2000
            eta *= 0.975
            Optimisers.adjust!(opt_state, eta)
        end
        println("i: $i; Loss: $l; eta: $eta")
    end
end

#Generate unconditional samples:
n_inference_samples = 5000
X0m = zeros(2, n_inference_samples) .< Inf #<- A mask with nothing conditioned (all ones)
X0 = MaskedState(ContinuousState(sampleX0(n_inference_samples)), X0m, X0m)
paths = Tracker()
samp = gen(P, X0, model, 0f0:0.005f0:1f0, tracker = paths)

#Generate conditional samples, where the mask gets encoded into X0:
n_masked_inference_samples = 500
X0m = rand(2, n_masked_inference_samples) .< 0.5
X0m[1,:] .= 1 .- X0m[2,:] #Making sure we don't have any double-masked points
conditionalX0 = MaskedState(ContinuousState(sampleX0(n_masked_inference_samples)), X0m, X0m)
tensor(conditionalX0)[.!(X0m)] .= (rand(2, n_masked_inference_samples) .* 0.1f0 .+ [1f0, -1f0])[.!(X0m)] #<- Condition on these specific values
conditional_samp = gen(P, conditionalX0, model, 0f0:0.005f0:1f0)

#Plotting:
pl = scatter(tensor(X0)[1,:],tensor(X0)[2,:], msw = 0, ms = 1, color = "blue", alpha = 0.5, size = (400,400), legend = :topleft, label = "X0")
X1true = sampleX1(n_inference_samples)
scatter!(X1true[1,:],X1true[2,:], msw = 0, ms = 1, color = "orange", alpha = 0.5, label = "X1 (true)")
scatter!(tensor(samp)[1,:],tensor(samp)[2,:], msw = 0, ms = 1, color = "green", alpha = 0.5, label = "X1 (generated)")
scatter!(tensor(conditional_samp)[1,:],tensor(conditional_samp)[2,:], msw = 0, ms = 1, color = "red", alpha = 0.5, label = "X1 (conditioned)")
display(pl)
savefig("conditioned_cat_$P.svg")
```
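
A note on the masks above: `MaskedState(S, cmask, lmask)` carries both a conditioning mask and a loss mask (this example passes the same array for both), so positions where the mask is 0 are held at the supplied values through `bridge` and `gen`, and are excluded from `floss`. The `src/mask.jl` changes below fix the mask polarity to match this convention.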

src/mask.jl

Lines changed: 5 additions & 4 deletions
```diff
@@ -64,10 +64,10 @@ unmask(X) = X
     cmask!(Xt_state, X1_state, cmask)
     cmask!(Xt, X1)
 
-Applies, in place, a conditioning mask, forcing elements (or slices) of `Xt` to be equal to `X1`, where `cmask` is 1.
+Applies, in place, a conditioning mask: only elements (or slices) of `Xt` where `cmask` is 1 are noised; where `cmask` is 0, elements are forced to equal `X1`.
 """
 function cmask!(Xt_state, X1_state, cmask)
-    endslices(Xt_state,cmask) .= endslices(X1_state,cmask)
+    endslices(Xt_state,.!cmask) .= endslices(X1_state,.!cmask)
     return Xt_state
 end
```
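
For intuition, a tiny stand-alone illustration of the corrected convention (hypothetical values, plain `Vector`s rather than states):

```julia
Xt    = [0.1, 0.2, 0.3]      # noised sample
X1    = [9.0, 9.0, 9.0]      # conditioning values
cmask = [true, false, true]  # 1 = free to vary, 0 = clamped to X1

Xt[.!cmask] .= X1[.!cmask]   # the effect of the fixed cmask! line
# Xt == [0.1, 9.0, 0.3]
```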

```diff
@@ -90,7 +90,7 @@ cmask!(Xt::Tuple, X1::Tuple) = map(cmask!, Xt, X1)
 """
     mask(X, Y)
 
-If `Y` is a `MaskedState`, `mask(X, Y)` returns a `MaskedState` with the content of `X` where elements of `Y.cmask` are 0, and `Y` where `Y.cmask` is 1.
+If `Y` is a `MaskedState`, `mask(X, Y)` returns a `MaskedState` with the content of `X` where elements of `Y.cmask` are 1, and `Y` where `Y.cmask` is 0.
 `cmask` and `lmask` are inherited from `Y`.
 If `Y` is not a `MaskedState`, `mask(X, Y)` returns `X`.
 """
```
```diff
@@ -107,7 +107,8 @@ bridge(P::UProcess, X0, X1::MaskedState, t) = mask(bridge(P, unmask(X0), X1.S, t
 
 #Mask passthroughs, because the masking gets handled elsewhere:
 step(P::UProcess, Xₜ::MaskedState, hat, s₁, s₂) = step(P, unmask(Xₜ), unmask(hat), s₁, s₂) #step is only called in gen, which handles the masking itself
-resolveprediction(X::MaskedState, Xₜ) = resolveprediction(unmask(X), unmask(Xₜ))
+resolveprediction(X, Xₜ) = resolveprediction(unmask(X), unmask(Xₜ))
+#resolveprediction(X, Xₜ::MaskedState) = resolveprediction(unmask(X), unmask(Xₜ))
```