
Commit aad2095: touchups
1 parent e86f438

5 files changed: 12 additions & 114 deletions

README.md

Lines changed: 4 additions & 3 deletions
@@ -5,11 +5,12 @@
 [![Build Status](https://github.com/MurrellGroup/Flowfusion.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Flowfusion.jl/actions/workflows/CI.yml?query=branch%3Amain)
 [![Coverage](https://codecov.io/gh/MurrellGroup/Flowfusion.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Flowfusion.jl)
 
-![Image](https://github.com/user-attachments/assets/f2754ba5-b798-4db9-8ce6-a0324b89a534)
-
 Flowfusion.jl is a Julia package for training and sampling from diffusion and flow matching models (and some things in between), across continuous, discrete, and manifold spaces, all in a single unified framework and interface.
 
-The animated shows samples from a model trained to steer a coupled 2D Brownian bridge diffusion in space, with an angular diffusion in hue. The hue endpoints are antipodal, and you can see both paths, in opposite angular directions, are sampled.
+![Image](https://github.com/user-attachments/assets/d739c07e-f9e9-4aef-932e-c36cae182391)
+![Image](https://github.com/user-attachments/assets/f2754ba5-b798-4db9-8ce6-a0324b89a534)
+
+The animated logo shows samples from a model trained to jointly transport a 2D point and an angular hue between two distributions. For the 2D point, the left side uses "Flow matching" with deterministic trajectories, and the right uses a Brownian bridge. For both sides, the angular hue is diffused via an angular Brownian bridge. The hue endpoints are antipodal, and you can see both paths, in opposite angular directions, are sampled.
 
 ## Features
 
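A minimal sketch (not part of this commit) of the coupled setup the new README text describes, written with the state and process types that appear in examples/logo_example.jl below. The hue manifold (Torus(1)), sample count, and endpoint values are illustrative assumptions.

```julia
# Hedged sketch: a 2D point (ContinuousState) coupled with an angular hue
# (ManifoldState on an assumed circular manifold). Values are illustrative.
using Flowfusion, Manifolds

M = Torus(1)                                      # assumed angular manifold for the hue
n = 4

X0 = (ContinuousState(randn(Float32, 2, n)),      # starting 2D points
      ManifoldState(M, fill([0.6f0], n)))         # starting hue angle
X1 = (ContinuousState(randn(Float32, 2, n) .+ 3f0),
      ManifoldState(M, fill([-2.54159f0], n)))    # antipodal hue endpoint

# Left panel of the logo: deterministic ("flow matching") spatial paths.
P_flow   = (Deterministic(),        ManifoldProcess(0.1f0))
# Right panel: Brownian-bridge spatial paths; the hue is an angular bridge in both.
P_bridge = (BrownianMotion(0.05f0), ManifoldProcess(0.1f0))

t  = rand(Float32, n)
Xt = bridge(P_bridge, X0, X1, t)                  # sample the bridged state at time t
```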

examples/logo_example.jl

Lines changed: 5 additions & 6 deletions
@@ -54,17 +54,18 @@ ManifoldState(M, Array{Float32}.(rand(M, n_samples)))
 sampleX0(n_samples) = ContinuousState(T.(stack(rand(flowinds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([0.6f0], n_samples))
 sampleX1(n_samples) = ContinuousState(T.(stack(rand(fusioninds, n_samples))) .+ rand(T, 2, n_samples) .* 0.01f0), ManifoldState(M, fill([-2.54159f0], n_samples))
 
-model = FModel(embeddim = 384, layers = 5)
+model = FModel(embeddim = 512, layers = 5)
 n_samples = 500
 
 #The process:
 P = (BrownianMotion(0.05f0), ManifoldProcess(0.1f0))
+#P = (Deterministic(), ManifoldProcess(0.1f0))
 
 #Optimizer:
 eta = 0.001
 opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.001), model)
 
-iters = 6000
+iters = 10000
 for i in 1:iters
     #Set up a batch of training pairs, and t, where X1 is a MaskedState:
     X0 = sampleX0(n_samples)
@@ -82,8 +83,8 @@ for i in 1:iters
     Flux.update!(opt_state, model, g[1])
     #Logging, and lr cooldown:
     if i % 10 == 0
-        if i > iters - 2000
-            eta *= 0.975
+        if i > iters - 3000
+            eta *= 0.98
             Optimisers.adjust!(opt_state, eta)
         end
         println("i: $i; Loss: $l; eta: $eta")
@@ -106,8 +107,6 @@ astate = tensor(samp[2])
 zcstate = tensor(X0[1])
 zastate = tensor(X0[2])
 
-#scatter(zcstate[1,:], zcstate[2,:], msw = 0, ms = 1.5, markerz = zastate[1,:], cmap = :hsv)
-#scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1.5, markerz = astate[1,:], cmap = :hsv)
 scatter(zcstate[1,:], zcstate[2,:], msw = 0, ms = 1.5, markerz = zastate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
 scatter!(cstate[1,:], cstate[2,:], msw = 0, ms = 1.5, markerz = astate[1,:], cmap = :hsv, label = :none, xlim = (-0.5, 5.5), ylim = (-1.5, 1.5))
 scatter!([-100,-100],[-100,-100], markerz = [-pi,pi], label = :none, colorbar = :none, axis=([], false))
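
Not part of the diff: a small self-contained sketch of what the cooldown change above does. The decay is applied once every 10 iterations, but only inside the final cooldown window, so the learning rate bottoms out near 6e-6 under the old schedule and near 2e-6 under the new one.

```julia
# Standalone illustration of the two cooldown schedules from the loop above:
# eta is multiplied by `decay` every 10th iteration during the last `window` iterations.
function final_eta(eta, iters, window, decay)
    for i in 1:iters
        if i % 10 == 0 && i > iters - window
            eta *= decay
        end
    end
    return eta
end

final_eta(0.001, 6000, 2000, 0.975)   # old schedule: ≈ 6.3e-6
final_eta(0.001, 10000, 3000, 0.98)   # new schedule: ≈ 2.3e-6
```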

src/bridge.jl

Lines changed: 1 addition & 3 deletions
@@ -32,7 +32,6 @@ Returns a state where `X.state` is not onehot.
 unhot(X::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = DiscreteState(X.K, onecold(X.state, 1:X.K))
 unhot(X::DiscreteState{<:AbstractArray{<:Integer}}) = X
 ForwardBackward.stochastic(T::Type, o::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = CategoricalLikelihood(T.(o.state .+ 0), zeros(T, size(o.state)[2:end]...))
-#TODO: onehot/unhot for masked state?
 
 """
     dense(X::DiscreteState; T = Float32)
@@ -138,9 +137,8 @@ function stack_tracker(tracker, field; tuple_index = 1)
 end
 
 
-
+#Todo: tesst Guide with MaskedState
 Guide(Xt::ManifoldState, X1::ManifoldState; kwargs...) = Guide(tangent_guide(Xt, X1; kwargs...))
-#MaskedState needs to be tested. Current setup disallows X1 being masked but Xt not.
 Guide(mXt::Union{MaskedState{<:ManifoldState}, ManifoldState}, mX1::MaskedState{<:ManifoldState}; kwargs...) = Guide(tangent_guide(mXt, mX1; kwargs...), mX1.cmask, mX1.lmask)
 
 #=If we want the model to directly predict the tangent coordinates, we use:
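
Not part of the diff: a hedged usage sketch for the Guide methods kept above. Guide(Xt, X1) wraps the tangent coordinates produced by tangent_guide, and the MaskedState method carries X1's cmask/lmask through to the Guide; the manifold and point values below are illustrative assumptions.

```julia
# Hedged sketch (assumed manifold and values): building a tangent-coordinate
# regression target from a noisy state Xt and the clean endpoint X1.
using Flowfusion, Manifolds

M  = Torus(1)                        # assumed angular manifold
Xt = ManifoldState(M, fill([0.1f0], 8))
X1 = ManifoldState(M, fill([2.0f0], 8))

ξ = Guide(Xt, X1)                    # wraps tangent_guide(Xt, X1)
# ξ.H holds the tangent coordinates; src/loss.jl compares a model prediction
# ξhat against ξ.H with a scaled, masked MSE via floss.
```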

src/loss.jl

Lines changed: 1 addition & 56 deletions
@@ -1,24 +1,3 @@
-#=######
-NOTES on what works:
-- Euclidean state:
-- - any compatible process, using floss
-- Manifold state:
-- - any compatible process, using tcloss
-- Discrete state:
-- - for a DiscreteProcess, only UniformUnmasking works properly. The rest have issues.
-- - works to using the ProbabilitySimplex in a ManifoldProcess.
-- - Either:
-- - - The process must have non-zero variance
-- - - or X0 must be a continuous distribution (ie. not discrete "corners") on the ProbabilitySimplex (in which case a deterministic process also works)
-=#######
-
-#This is badness that doesn't work:
-#rotangle(rots::AbstractArray{T,3}) where T = acos.(clamp.((rots[1,1,:] .+ rots[2,2,:] .+ rots[3,3,:] .- 1) ./ 2, T(-0.99), T(0.99)))
-#rotangle(rots::AbstractArray) = reshape(rotangle(reshape(rots, 3, 3, :)), 1, size(rots)[3:end]...)
-#torangle(x, y) = mod.(y .- x .+ π, 2π) .- π
-#msra(X̂₁, X₁) = rotangle(batched_mul(batched_transpose(tensor(X̂₁)), tensor(X₁))).^2 #Mean Squared Angle
-#msta(X̂₁, X₁) = sum(torangle(tensor(X̂₁), tensor(X₁)), dims=1).^2 #Mean Squared Toroidal Angle
-
 mse(X̂₁, X₁) = abs2.(tensor(X̂₁) .- tensor(X₁)) #Mean Squared Error
 lce(X̂₁, X₁) = -sum(tensor(X₁) .* logsoftmax(tensor(X̂₁)), dims=1) #Logit Cross Entropy
 kl(P,Q) = sum(softmax(tensor(P)) .* (logsoftmax(tensor(P)) .- log.(tensor(Q))), dims=1) #Kullback-Leibler Divergence
@@ -55,45 +34,11 @@ floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState),
 #For a discrete process, X̂₁ will be a distribution, and X₁ will have to be a onehot before going onto the gpu.
 floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
 floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
-#floss(P::fbu(ManifoldProcess{Rotations(3)}), X̂₁, X₁::msu(ManifoldState{Rotations(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
-#floss(P::fbu(ManifoldProcess{SpecialOrthogonal(3)}), X̂₁, X₁::msu(ManifoldState{SpecialOrthogonal(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
-#floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = scaledmaskedmean(msta(X̂₁, X₁), c, getlmask(X₁))
-
 floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Union{AbstractArray, Real}) = sum(floss.(P, X̂₁, X₁, (c,)))
 floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Tuple) = sum(floss.(P, X̂₁, X₁, c))
-
-#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights.
-
-"""
-    tcloss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ, c, mask = nothing)
-
-Where `ξhat` is the predicted tangent coordinates, and `ξ` is the true tangent coordinates.
-"""
 floss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ::Guide, c) = scaledmaskedmean(mse(ξhat, ξ.H), c, getlmask(ξ))
-#tcloss(P::fbu(DiscreteProcess), ξhat, ξ, c, mask = nothing) = scaledmaskedmean(rkl(ξhat, ξ), c, mask)
-
-
-
-
-#=
-#Doesn't help to do it this way
-"""
-    tangent_coordinates(P::DiscreteProcess, Xt::DiscreteState, X1)
-
-Computes (a weighted mixture of) Doob's h-transform(s) that would condition the current state Xt (which must be a discrete value)
-to end at X1 (which can be a distribution) under P. Maybe.
-"""
-function tangent_coordinates(P::DiscreteProcess, X1::DiscreteState, t)
-    #(for a single column) for state=i at 1-t, H_j(t)/H_i(t) is the rate scaling ratio per Doob's h-transform.
-    #If the model can learn this directly, we can gen.
-    H = backward(X1, P, 1 .- t)
-    scale = sum(H.dist, dims = 1)
-    H.dist ./= scale
-    return H
-end
-=#
-
 
+#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights.
 
 ########################################################################
 #Manifold-specific helper functions
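
Not part of the diff: a hedged sketch of the one-hot requirement the surviving floss methods encode for discrete states. An integer-encoded DiscreteState target triggers the error above; encoding it with onehot(X₁) first routes to the logit-cross-entropy branch. The sizes and the commented-out process constructor are illustrative assumptions.

```julia
# Hedged sketch (sizes and process constructor assumed): discrete targets must be
# one-hot encoded before the loss (and before moving to the GPU).
using Flowfusion

K  = 4
X₁ = DiscreteState(K, rand(1:K, 16))   # integer-encoded target: floss would error on this
oh = onehot(X₁)                        # one-hot encode, as the error message instructs

X̂₁ = randn(Float32, K, 16)            # stand-in for model logits over the K categories
# l = floss(P, X̂₁, oh, 1.0f0)         # with P a DiscreteProcess (e.g. UniformUnmasking;
                                       # its constructor is not shown in this diff)
```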

src/mask.jl

Lines changed: 1 addition & 46 deletions
@@ -1,16 +1,13 @@
-#Need to test key mask functions on ContinuousState, DiscreteState, CategoricalLikelihood, ManifoldState
-
 #=#####################
 Conditioning mask behavior:
 The typical use is that it makes sense, during training, to construct the conditioning mask on the training observation, X1.
 During inference, the conditioning mask (and conditioned-upon state) has to be present on X1.
 This dictates the behavior of the masking:
-- When bridge() is called, the mask, and the state where mask=1, are inherited from X1.
+- When bridge() is called, the mask, and the state where mask=0, are inherited from X1.
 - When gen is called, the state and mask will be propogated from X0 through all of the Xts.
 =#####################
 
 
-#import Base.copy
 ForwardBackward.tensor(X::MaskedState) = tensor(X.S)
 Base.copy(X::MaskedState) = MaskedState(copy(X.S), copy(X.cmask), copy(X.lmask))
 
@@ -22,21 +19,9 @@ For example, if `m` is a boolean array, then `size(a)[ndims(a)-ndims(m):end] ==
 """
 endslices(a,m) = @view a[ntuple(Returns(:),ndims(a)-ndims(m))...,m]
 
-
-#=
-Need to handle:
-Xt = stochastic(Float32, bridge(P, X0, X1, t))
-which means "stochastic" needs to preserve mask, and CategoricalLikelihoods need to be able to be masked.
-onehot too
-=#
-
-
-
 onehot(X::MaskedState{<:DiscreteState{<:AbstractArray{<:Integer}}}) = MaskedState(onehot(X.S), X.cmask, X.lmask)
 ForwardBackward.stochastic(T::Type, o::MaskedState) = MaskedState(stochastic(T, o.S), o.cmask, o.lmask)
 
-
-
 getlmask(X1::UState) = X1.lmask
 getlmask(X1::State) = nothing
 getcmask(X1::UState) = X1.cmask
@@ -73,32 +58,8 @@ function cmask!(Xt_state, X1_state, cmask)
     return Xt_state
 end
 
-#THIS IS NOT MODIFYING - NEED TO RETHINK
-#=
-function cmask!(ohXt_state::Union{OneHotArray, OneHotMatrix}, ohX1_state::Union{OneHotArray, OneHotMatrix}, cmask)
-    K = size(ohXt_state, 1)
-    Xt_state, X1_state = onecold(ohXt_state, 1:K), onecold(ohX1_state, 1:K)
-    endslices(Xt_state,.!cmask) .= endslices(X1_state,.!cmask)
-    return onehotbatch(Xt_state, 1:K)
-end
-=#
-
 cmask!(Xt::Union{State, MaskedState{<:State}}, X1::MaskedState{<:StateLikelihood}) = error("Cannot condition a state on a Likelihood")
 
-#=
-function cmask!(Xt::MaskedState, X1::MaskedState)
-    cmask!(Xt.S.state, X1.S.state, X1.cmask)
-    size(Xt.cmask) != size(X1.cmask) && error("cmask dimensions must match")
-    Xt.cmask .= X1.cmask
-    return Xt
-end
-
-cmask!(Xt_state, X1_state, cmask::Nothing) = Xt_state
-cmask!(Xt, X1::State) = Xt
-cmask!(Xt, X1::StateLikelihood) = Xt
-cmask!(Xt::Tuple, X1::Tuple) = map(cmask!, Xt, X1)
-=#
-
 """
     mask(X, Y)
 
@@ -126,10 +87,4 @@ bridge(P::UProcess, X0, X1::MaskedState, t) = mask(bridge(P, unmask(X0), X1.S, t
 #Mask passthroughs, because the masking gets handled elsewhere:
 step(P::UProcess, Xₜ::MaskedState, hat, s₁, s₂) = step(P, unmask(Xₜ), unmask(hat), s₁, s₂) #step is only called in gen, which handles the masking itself
 resolveprediction(X, Xₜ) = resolveprediction(unmask(X), unmask(Xₜ))
-#resolveprediction(X, Xₜ::MaskedState) = resolveprediction(unmask(X), unmask(Xₜ))
-
-print("?")
 
-#REMOVE?:
-#⊙ itself doesn't force the masks - it just propogates them. The forcing happens elsewhere.
-#ForwardBackward.:⊙(a::MaskedState, b::MaskedState; kwargs...) = MaskedState(⊙(a.S, b.S; kwargs...), a.cmask .* b.cmask, a.lmask .* b.lmask)
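
Not part of the diff: a hedged sketch of the conditioning convention the updated comment describes. X1 is wrapped in a MaskedState; when bridge() is called, positions where cmask = 0 keep X1's state and the mask itself is inherited, and gen propagates that conditioned state and mask through every Xt. The shapes and the choice lmask = cmask are illustrative assumptions.

```julia
# Hedged sketch (shapes and mask choices assumed): conditioning via MaskedState.
using Flowfusion

n     = 8
P     = BrownianMotion(0.05f0)
X0    = ContinuousState(randn(Float32, 2, n))
X1raw = ContinuousState(randn(Float32, 2, n))

cmask = rand(Bool, n)          # 0 => that position is conditioned on (copied from) X1
lmask = cmask                  # assumed: drop conditioned positions from the loss
X1    = MaskedState(X1raw, cmask, lmask)

t  = rand(Float32, n)
Xt = bridge(P, X0, X1, t)      # Xt inherits X1's state and mask where cmask = 0
```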
