
Commit e0e9e9e: Masking and tests
1 parent: 72546b4

File tree

9 files changed (+1333, -140 lines)


examples/examples.jl

Lines changed: 1002 additions & 0 deletions
Large diffs are not rendered by default.

examples/probabilitysimplex.jl

Lines changed: 4 additions & 4 deletions
@@ -32,7 +32,7 @@ function (f::PSModel)(t, Xt)
     return reshape(l.decode(x), :, 2, len) .* (1.05f0 .- expand(t, 3))
 end
 
-model = PSModel(embeddim = 256, l = 2, K = 33, layers = 3)
+model = PSModel(embeddim = 128, l = 2, K = 33, layers = 2)
 
 sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples)
 sampleX0(n_samples) = rand(25:32, 2, n_samples)
@@ -44,7 +44,7 @@ M = ProbabilitySimplex(32)
 P = ManifoldProcess(0.5f0)
 
 eta = 0.01
-opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.01), model)
+opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.0001), model)
 
 iters = 5000
 for i in 1:iters
@@ -55,7 +55,7 @@ for i in 1:iters
     #Construct the bridge:
     Xt = bridge(P, X0, X1, t)
     #Get the Xt->X1 tangent coordinates:
-    ξ = Flowfusion.tangent_coordinates(Xt, X1)
+    ξ = tangent_guide(Xt, X1)
    #Gradient:
    l,g = Flux.withgradient(model) do m
        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
@@ -76,7 +76,7 @@ end
 n_inference_samples = 5000
 X0 = ManifoldState(T, M, sampleX0(n_inference_samples));
 paths = Tracker()
-X1pred = (t,Xt) -> apply_tangent_coordinates(Xt, model(t,tensor(Xt)))
+X1pred = (t,Xt) -> Guide(model(t,tensor(Xt)))
 samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)
 
 #Plot the X0 and generated X1:
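
Note on the API change in this file: `Flowfusion.tangent_coordinates` is renamed to the exported `tangent_guide`, and at inference the raw tangent prediction is now wrapped in `Guide` rather than applied via `apply_tangent_coordinates` directly. A minimal sketch of the resulting pattern, reusing the names defined in this example:

    # Training: regress the Xt -> X1 tangent coordinates.
    Xt = bridge(P, X0, X1, t)
    ξ = tangent_guide(Xt, X1)
    l, g = Flux.withgradient(model) do m
        tcloss(P, m(t, tensor(Xt)), ξ, scalefloss(P, t))
    end

    # Inference: wrap the prediction in a Guide; gen() retracts it onto the
    # manifold via apply_tangent_coordinates (see src/bridge.jl below).
    X1pred = (t, Xt) -> Guide(model(t, tensor(Xt)))
    samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0)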

examples/torus.jl

Lines changed: 5 additions & 5 deletions
@@ -31,7 +31,7 @@ function (f::TModel)(t, Xt)
     return (l.decode(x) .* (1.05f0 .- tv))
 end
 
-model = TModel(embeddim = 256, layers = 3, spacedim = 2)
+model = TModel(embeddim = 128, layers = 3, spacedim = 2)
 
 T = Float32
 sampleX0(n_samples) = rand(T, 2, n_samples) .+ [2.1f0, 1]
@@ -45,7 +45,7 @@ P = ManifoldProcess(0.2f0)
 eta = 0.01
 opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.00001), model)
 
-iters = 8000
+iters = 4000
 for i in 1:iters
     #Set up a batch of training pairs, and t
     X1 = ManifoldState(M, eachcol(sampleX1(n_samples))) #Note: eachcol
@@ -54,7 +54,7 @@ for i in 1:iters
     #Construct the bridge:
     Xt = bridge(P, X0, X1, t)
     #Compute the tangent coordinates:
-    ξ = Flowfusion.tangent_coordinates(Xt, X1)
+    ξ = tangent_guide(Xt, X1)
    #Gradient
    l,g = Flux.withgradient(model) do m
        tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
@@ -63,7 +63,7 @@ for i in 1:iters
     Flux.update!(opt_state, model, g[1])
     #Logging, and lr cooldown:
     if i % 10 == 0
-        if i > iters - 3000
+        if i > iters - 2000
             eta *= 0.975
             Optimisers.adjust!(opt_state, eta)
         end
@@ -76,7 +76,7 @@ n_inference_samples = 2000
 X0 = ManifoldState(M, eachcol(sampleX0(n_inference_samples)))
 paths = Tracker()
 #We wrap the model, because it was predicting tangent coordinates, not the actual state:
-X1pred = (t,Xt) -> BackwardGuide(model(t,tensor(Xt)))
+X1pred = (t,Xt) -> Guide(model(t,tensor(Xt)))
 samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)
 
 #Plot the torus, with samples, and trajectories:
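
The plotting step can pull full trajectories out of the `Tracker`. A sketch, assuming the tracker stores the visited states under a field named `xt` (a hypothetical name; the `Tracker` fields are not shown in this diff):

    traj = stack_tracker(paths, :xt)  # stacks the tensor(Xt) snapshots along a new trailing dimension
    size(traj)                        # (2, n_inference_samples, n_steps) for this 2D torus example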

src/Flowfusion.jl

Lines changed: 28 additions & 5 deletions
@@ -1,7 +1,30 @@
+#=
+Need to test/do:
+Urgent:
+ - Test tuples!
+ - Masking (cmask) on all state types for bridge and gen
+ - Masking (lmask) on all state types for both losses
+ - tensor on masked states
+ - FProcess and whether it matches the target where allowed. Need to come up with a policy on using FProcess with InterpolatingDiscreteProcesses
+ - X1 pred for rotations (add angle/axis loss back in just because yolo)
+ - self-conditioning
+ - GPU use of all state types
+Later:
+ - Make a table of Manifolds where you test whether the key functions are defined, with checkboxes and timing for diffusion and flow.
+ - Make a table of commands for key types of diffusion/flow. Columns for Process, X0/X1 setup, Xt bridge, loss, gen where things like softmax, Guide, etc are clear.
+ - Compute probability velocities for UniformDiscrete and PiQ so these can flow.
+=#
+
+
+
+
+
 module Flowfusion
 
 using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
 
+include("types.jl")
+include("mask.jl")
 include("bridge.jl")
 include("loss.jl")
 include("processes.jl")
@@ -11,18 +34,18 @@ export
     InterpolatingDiscreteFlow,
     NoisyInterpolatingDiscreteFlow,
     MaskedState,
+    Guide,
+    tangent_guide,
     bridge,
     scalefloss,
     gen,
     Tracker,
     stack_tracker,
     onehot,
-    FProcess,
-    tangent_coordinates,
-    BackwardGuide,
-    apply_tangent_coordinates,
+    FProcess,
     floss,
-    tcloss
+    tcloss,
+    dense
 
 
 #Useful for demos etc:
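
Among the new exports, `dense` and `onehot` (defined in src/bridge.jl below) handle representation changes for discrete states. A small sketch; the `DiscreteState(K, state)` constructor comes from ForwardBackward.jl, as used elsewhere in this commit:

    X = DiscreteState(4, [1, 2, 3])  # K = 4 categories at 3 positions
    Xoh = onehot(X)                  # DiscreteState whose .state is a OneHotArray
    Xd = dense(X)                    # CategoricalLikelihood with Float32 entries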

src/bridge.jl

Lines changed: 55 additions & 82 deletions
@@ -5,96 +5,39 @@ Assumptions:
 - Default sampling steps are FProcess.F(t) with even t intervals [NOTE TO SELF: Intervals should be F(t2)-F(t1)]
 =#####################
 
-struct FProcess{A,B}
-    P::A #Process
-    F::B #Time transform
-end
-
-UProcess = Union{Process,FProcess}
 process(P::FProcess) = P.P
 process(P::Process) = P
 
 tscale(P::Process, t) = t
 tscale(P::FProcess, t) = P.F.(t)
 
-#=#####################
-Conditioning mask behavior:
-The typical use is that it makes sense, during training, to construct the conditioning mask on the training observation, X1.
-During inference, the conditioning mask (and conditioned-upon state) has to be present on X1.
-This dictates the behavior of the masking:
-- When bridge() is called, the mask, and the state where mask=1, are inherited from X1.
-- When gen is called, the state and mask will be propagated from X0 through all of the Xts.
-=#####################
-struct MaskedState{A,B,C}
-    S::A #State
-    cmask::B #Conditioning mask. 1 = Xt=X1
-    lmask::C #Loss mask. 1 = included in loss
-end
-
-#For when we want to predict the transitions instead of X1hat
-struct BackwardGuide{A}
-    H::A
-end
-ForwardBackward.:⊙(a::CategoricalLikelihood, b::BackwardGuide) = ⊙(a, copytensor!(copy(a), b.H))
-
-#⊙ itself doesn't force the masks - it just propagates them. The forcing happens elsewhere.
-ForwardBackward.:⊙(a::MaskedState, b::MaskedState; kwargs...) = MaskedState(⊙(a.S, b.S; kwargs...), a.cmask .* b.cmask, a.lmask .* b.lmask)
-
 Adapt.adapt_structure(to, S::ForwardBackward.DiscreteState) = ForwardBackward.DiscreteState(S.K, Adapt.adapt(to, S.state))
 Adapt.adapt_structure(to, S::ForwardBackward.ContinuousState) = ForwardBackward.ContinuousState(Adapt.adapt(to, S.state))
 Adapt.adapt_structure(to, S::ForwardBackward.CategoricalLikelihood) = ForwardBackward.CategoricalLikelihood(Adapt.adapt(to, S.dist), Adapt.adapt(to, S.log_norm_const))
-Adapt.adapt_structure(to, MS::MaskedState{<:State}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
-Adapt.adapt_structure(to, MS::MaskedState{<:CategoricalLikelihood}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
 Adapt.adapt_structure(to, S::ForwardBackward.ManifoldState) = ForwardBackward.ManifoldState(S.M, Adapt.adapt(to, S.state))
 
-UState = Union{State,MaskedState, BackwardGuide}
-
-ForwardBackward.tensor(X::MaskedState) = tensor(X.S)
-
-import Base.copy
-copy(X::MaskedState) = MaskedState(copy(X.S), copy(X.cmask), copy(X.lmask))
-
 """
-    endslices(a,m)
+    onehot(X)
 
-Returns a view of `a` where slices specified by `m` are selected. `m` can be multidimensional, but the dimensions of m must match the last dimensions of `a`.
-For example, if `m` is a boolean array, then `size(a)[ndims(a)-ndims(m):end] == size(m)`.
+Returns a state where `X.state` is a onehot array.
 """
-endslices(a,m) = @view a[ntuple(Returns(:),ndims(a)-ndims(m))...,m]
+onehot(X::DiscreteState{<:AbstractArray{<:Integer}}) = DiscreteState(X.K, onehotbatch(X.state, 1:X.K))
+onehot(X::DiscreteState{<:OneHotArray}) = X
+ForwardBackward.stochastic(T::Type, o::DiscreteState{<:OneHotArray}) = CategoricalLikelihood(T.(o.state .+ 0), zeros(T, size(o.state)[2:end]...))
 
 """
-    cmask!(Xt_state, X1_state, cmask)
-    cmask!(Xt, X1)
+    dense(X::DiscreteState; T = Float32)
 
-Applies, in place, a conditioning mask, forcing elements (or slices) of `Xt` to be equal to `X1`, where `cmask` is 1.
+Converts `X` to an appropriate dense representation. If `X` is a `DiscreteState`, then `X` is converted to a `CategoricalLikelihood` with default eltype Float32.
+If `X` is a "onehot" CategoricalLikelihood then `X` is converted to a fully dense one.
 """
-function cmask!(Xt_state, X1_state, cmask)
-    endslices(Xt_state,cmask) .= endslices(X1_state,cmask)
-    return Xt_state
-end
+dense(X::DiscreteState; T = Float32) = stochastic(T, X)
 
-cmask!(Xt_state, X1_state, cmask::Nothing) = Xt_state
-cmask!(Xt, X1::State) = Xt
-cmask!(Xt, X1::StateLikelihood) = Xt
-cmask!(Xt, X1::MaskedState) = cmask!(Xt.S.state, X1.S.state, X1.cmask)
-cmask!(Xt, X1::MaskedState{<:CategoricalLikelihood}) = error("Cannot condition on a CategoricalLikelihood")
-cmask!(x̂₁::Tuple, x₀::Tuple) = map(cmask!, x̂₁, x₀)
-
-
-#copytensor! and predictresolve are used to handle the state translation that happens in gen(...).
-#We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc).
-#This means we need to automagically create a State (typical for the continuous case) or Likelihood (typical for the discrete case) from the tensor.
-#But the user may return a State in the Discrete case (for massive state spaces with sub-linear sampling), and a Likelihood in the Continuous case (for variance matching models)
-#This also needs to handle MaskedStates (needs testing).
-#We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
-#Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass through.
-#When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
 
 function copytensor!(dest, src)
     tensor(dest) .= tensor(src)
     return dest
 end
-#copytensor!(dest::Tuple, src::Tuple) = map(copytensor!, dest, src)
 
 #resolveprediction exists to stop bridge from needing multiple definitions.
 #Tuple broadcast:
@@ -106,11 +49,8 @@ resolveprediction(X̂₁, Xₜ::State) = copytensor!(copy(Xₜ), X̂₁) #Return
 resolveprediction(X̂₁::State, Xₜ) = X̂₁
 resolveprediction(X̂₁::State, Xₜ::State) = X̂₁
 resolveprediction(X̂₁::StateLikelihood, Xₜ) = X̂₁
-
-#Passthrough if the model returns a BackwardGuide, because we have a custom bridge for that.
-resolveprediction(G::BackwardGuide, Xₜ::DiscreteState) = G
-resolveprediction(G::BackwardGuide, Xₜ::ManifoldState) = apply_tangent_coordinates(Xₜ, G.H)
-#We could also add a case for where the guide is a tangent coordinate and X₀ is a ManifoldState.
+#Handles the case where the model returns a Guide of tangent coordinates:
+resolveprediction(G::Guide, Xₜ::ManifoldState) = apply_tangent_coordinates(Xₜ, G.H)
 
 
 
@@ -124,24 +64,20 @@ If `X1` is a `MaskedState`, then `Xt` will equal `X1` where the conditioning mas
 The same `t` and (optionally) `t0` will be used for all elements. If you need a different `t` for each Process/State, broadcast with `bridge.(P, X0, X1, t0, t)`.
 """
 
-function bridge(P::UProcess, X0::UState, X1, t0, t)
+function bridge(P::UProcess, X0, X1, t0, t)
     T = eltype(t)
     tF = T.(tscale(P,t) .- tscale(P,t0))
     tB = T.(tscale(P,1) .- tscale(P,t))
-    endpoint_conditioned_sample(cmask!(X0,X1), X1, process(P), tF, tB)
+    endpoint_conditioned_sample(X0, X1, process(P), tF, tB)
 end
 bridge(P, X0, X1, t) = bridge(P, X0, X1, eltype(t)(0.0), t)
 bridge(P::Tuple{Vararg{UProcess}}, X0::Tuple{Vararg{UState}}, X1::Tuple, t0, t) = bridge.(P, X0, X1, (t0,), (t, ))
 
 #Step is like bridge (and falls back to where possible). But sometimes we only have enough to take an Euler step (which is ok when `s₂-s₁` is small).
 step(P, Xₜ, hat, s₁, s₂) = bridge(P, Xₜ, hat, s₁, s₂)
 step(P::Tuple{Vararg{UProcess}}, Xₜ::Tuple{Vararg{UState}}, hat::Tuple, s₁, s₂) = step.(P, Xₜ, hat, (s₁,), (s₂, ))
-#step(P::DiscreteProcess, Xₜ::DiscreteState, hat::BackwardGuide, s₁, s₂) = rand(forward(Xₜ, P, s₂ .- s₁) ⊙ hat) #<- Doesn't work
-
+#step(P::DiscreteProcess, Xₜ::DiscreteState, hat::Guide, s₁, s₂) = rand(forward(Xₜ, P, s₂ .- s₁) ⊙ hat) #<- Doesn't work
 
-#####Add MaskedState case(s)######
-
-##################################
 
 """
     gen(P, X0, model, steps; tracker=Returns(nothing), midpoint = false)
@@ -156,14 +92,12 @@ function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, ste
     for (s₁, s₂) in zip(steps, steps[begin+1:end])
         t = midpoint ? (s₁ + s₂) / 2 : s₁
         hat = resolveprediction(model(t, Xₜ), Xₜ)
-        Xₜ = step(P, Xₜ, hat, s₁, s₂)
-        cmask!(Xₜ, X₀)
+        Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀)
         tracker(t, Xₜ, hat)
     end
     return Xₜ
 end
 
-
 gen(P, X₀, model, args...; kwargs...) = gen((P,), (X₀,), (t, Xₜ) -> (model(t[1], Xₜ[1]),), args...; kwargs...)[1]
 
 struct Tracker <: Function
@@ -183,4 +117,43 @@ end
 
 function stack_tracker(tracker, field; tuple_index = 1)
     return stack([tensor(data[tuple_index]) for data in getproperty(tracker, field)])
-end
+end
+
+
+
+#=If we want the model to directly predict the tangent coordinates, we use:
+ - tangent_guide outside the gradient call to get the thing the model will predict
+ - apply_tangent_coordinates during gen, to provide X̂₁ when the model is predicting the tangent coordinates
+ - the loss should just be the MSE between the predicted tangent coordinates and the true tangent coordinates
+Note: this gives you an invariance for free, since the model is predicting the change from Xt that results in X1.
+=#
+"""
+    tangent_guide(Xt::ManifoldState, X1::ManifoldState)
+
+Computes the coordinate vector (in the default basis) pointing from `Xt` to `X1`.
+"""
+function tangent_guide(Xt::ManifoldState, X1::ManifoldState; inverse_retraction_method=default_inverse_retraction_method(X1.M))
+    T = eltype(tensor(X1))
+    d = manifold_dimension(X1.M)
+    ξ = zeros(T, d, size(Xt.state)...)
+    temp_retract = inverse_retract(X1.M, Xt.state[1], X1.state[1], inverse_retraction_method)
+    for ind in eachindex(Xt.state)
+        inverse_retract!(X1.M, temp_retract, Xt.state[ind], X1.state[ind], inverse_retraction_method)
+        ξ[:,ind] .= get_coordinates(X1.M, Xt.state[ind], temp_retract)
+    end
+    return ξ
+end
+
+"""
+    apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=default_retraction_method(Xt.M))
+
+Returns `X̂₁` where each point is the result of retracting `Xt` by the corresponding tangent coordinate vector `ξ`.
+"""
+function apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=default_retraction_method(Xt.M))
+    X̂₁ = copy(Xt)
+    for ind in eachindex(Xt.state)
+        X = get_vector(Xt.M, Xt.state[ind], ξ[:,ind])
+        retract!(Xt.M, X̂₁.state[ind], Xt.state[ind], X, retraction_method)
+    end
+    return X̂₁
+end
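
`apply_tangent_coordinates` inverts `tangent_guide` up to retraction error, which gives a quick sanity check. A sketch, assuming `ManifoldState(M, points)` accepts a vector of points (as with the `eachcol` usage in examples/torus.jl):

    using Manifolds
    M = Sphere(2)
    Xt = ManifoldState(M, [rand(M) for _ in 1:8])
    X1 = ManifoldState(M, [rand(M) for _ in 1:8])
    ξ = tangent_guide(Xt, X1)              # coordinates pointing from Xt to X1
    X̂₁ = apply_tangent_coordinates(Xt, ξ)  # should land back on X1
    maximum(abs.(tensor(X̂₁) .- tensor(X1)))  # ≈ 0 (exact for the sphere's exp/log, barring antipodal points)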
