1
- # ToDo: Incorporate FProcesses, with their schedules. The bridge behavior should already be correct,
2
- # and the fallback doob should be correct if delta is passed through the schedule.
3
- # But for the closed form we'll need to mod the velocities per the gradient, etc, and the same when stepping.
1
# Note: It is not yet clear exactly which method from the literature this corresponds to. Not well tested!
4
2
5
- struct DoobMatchingFlow{Proc} <: Process
3
"""
    DoobMatchingFlow{Proc, B, F} <: Process

Process that wraps an underlying process `P` together with two settings.

# Fields
- `P`: the wrapped process.
- `onescale`: controls whether the "step" is unit scale or "time remaining" scale.
- `transform`: transforms the output of the model to the rate space. Must act on
  the whole tensor.

Losses can be compared for different `transform`s, but not for different `onescale`.
"""
struct DoobMatchingFlow{Proc, B, F} <: Process
    P::Proc
    # Need to think carefully about schedules in all this...
    onescale::B
    transform::F
end
8
- export DoobMatchingFlow
9
+
10
"""
    DoobMatchingFlow(P::DiscreteProcess)
    DoobMatchingFlow(P::DiscreteProcess, transform::Function)
    DoobMatchingFlow(P::DiscreteProcess, onescale::Bool)

Convenience constructors. Whatever is not supplied defaults to `onescale = true`
and `transform = NNlib.softplus`.
"""
function DoobMatchingFlow(P::DiscreteProcess)
    # x -> exp.(clamp.(x, -100, 11)) also works as a transform, but is scary
    return DoobMatchingFlow(P, true, NNlib.softplus)
end

function DoobMatchingFlow(P::DiscreteProcess, transform::Function)
    return DoobMatchingFlow(P, true, transform)
end

function DoobMatchingFlow(P::DiscreteProcess, onescale::Bool)
    return DoobMatchingFlow(P, onescale, NNlib.softplus)
end
13
+
14
"""
    onescale(P::DoobMatchingFlow, t)

Return the scale factor used at time `t`: `1 .- t` when `P.onescale` is true,
otherwise the constant one converted to `eltype(t)`.
"""
function onescale(P::DoobMatchingFlow, t)
    if P.onescale
        return 1 .- t
    end
    return eltype(t)(1)
end
15
"""
    mulexpand(t, x)

Multiply `x` elementwise by `t` after passing `t` through `expand(t, ndims(x))`.
`expand` is defined elsewhere; presumably it reshapes `t` so that it broadcasts
against `x` (TODO confirm).
"""
function mulexpand(t, x)
    expanded = expand(t, ndims(x))
    return expanded .* x
end
9
16
10
17
"""
    Flowfusion.bridge(p::DoobMatchingFlow, x0, x1, t)

Bridge for a `DoobMatchingFlow` over integer-coded discrete states: delegates
directly to `bridge` on the wrapped process `p.P`.
"""
function Flowfusion.bridge(
    p::DoobMatchingFlow,
    x0::DiscreteState{<:AbstractArray{<:Signed}},
    x1::DiscreteState{<:AbstractArray{<:Signed}},
    t,
)
    return bridge(p.P, x0, x1, t)
end
11
- # Finite diff fallback for when we don't have a closed form for the forward positive velocities:
18
+
12
19
"""
    fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta)

Finite-difference fallback for when there is no closed form for the forward
positive velocities.

Combines (`⊙`) a `forward` step of size `delta` from `Xt` with a `backward` step
of size `(1 .- t) .- delta` from `X1`, subtracts the one-hot encoding of `Xt`,
and divides by `delta`.

# Keywords
- `delta`: finite-difference step; defaults to `1e-5` converted to `eltype(t)`.
"""
function fallback_doob(
    P::DiscreteProcess,
    t,
    Xt::DiscreteState,
    X1::DiscreteState;
    delta = eltype(t)(1e-5),
)
    remaining = (1 .- t) .- delta
    stepped = forward(Xt, P, delta) ⊙ backward(X1, P, remaining)
    current = tensor(onehot(Xt))
    return (tensor(stepped) .- current) ./ delta
end
25
32
26
33
"""
    forward_positive_velocities(Xt::DiscreteState, P::PiQ)

Return `P.r` times `P.π` normalized by its sum, multiplied by
`1 .- onehot(Xt)` so that the entry of the current state is zeroed.
"""
function forward_positive_velocities(Xt::DiscreteState, P::PiQ)
    normalized = P.π ./ sum(P.π)
    notcurrent = 1 .- tensor(onehot(Xt))
    return (P.r .* normalized) .* notcurrent
end

"""
    doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState)

Guide for `PiQ` processes; forwards to `closed_form_doob`.
"""
function doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end
28
-
29
35
"""
    forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where {T}

Return `P.μ` times `1 / (Xt.K - 1)` (converted to `T`), multiplied by
`1 .- onehot(Xt)` so that the entry of the current state is zeroed.
"""
function forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where {T}
    share = T((1 ./ (Xt.K - 1)))
    notcurrent = 1 .- tensor(onehot(Xt))
    return (P.μ .* share) .* notcurrent
end

"""
    doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState)

Guide for `UniformUnmasking` processes; forwards to `closed_form_doob`.
"""
function doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end
31
-
32
37
"""
    forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where {T}

Return `P.μ` times `1 / (Xt.K * (1 - 1 / Xt.K))` (converted to `T`), multiplied
by `1 .- onehot(Xt)` so that the entry of the current state is zeroed.
"""
function forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where {T}
    share = T(1 / (Xt.K * (1 - 1 / Xt.K)))
    notcurrent = 1 .- tensor(onehot(Xt))
    return (P.μ * share) .* notcurrent
end

"""
    doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState)

Guide for `UniformDiscrete` processes; forwards to `closed_form_doob`.
"""
function doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end
34
39
35
- Guide (P:: DoobMatchingFlow , t, Xt:: DiscreteState , X1:: DiscreteState ) = Flowfusion. Guide (doob_guide (P. P, t, Xt, X1))
36
- Guide (P:: DoobMatchingFlow , t, mXt:: Union{MaskedState{<:DiscreteState}, DiscreteState} , mX1:: MaskedState{<:DiscreteState} ) = Guide (doob_guide (P. P, t, mXt, mX1), mX1. cmask, mX1. lmask)
40
# Guide for unmasked discrete states: the doob guide of the wrapped process,
# multiplied by the `onescale` factor at time `t`, wrapped in `Flowfusion.Guide`.
function Guide(P::DoobMatchingFlow, t, Xt::DiscreteState, X1::DiscreteState)
    scaled = mulexpand(onescale(P, t), doob_guide(P.P, t, Xt, X1))
    return Flowfusion.Guide(scaled)
end

# Guide for a masked target: same scaled doob guide, carrying along the
# `cmask` and `lmask` of `mX1`.
function Guide(
    P::DoobMatchingFlow,
    t,
    mXt::Union{MaskedState{<:DiscreteState}, DiscreteState},
    mX1::MaskedState{<:DiscreteState},
)
    scaled = mulexpand(onescale(P, t), doob_guide(P.P, t, mXt, mX1))
    return Guide(scaled, mX1.cmask, mX1.lmask)
end
37
42
38
- function velo_step (Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , delta_t, velocity)
43
"""
    rate_constraint(Xt, X̂₁, f)

Turn the raw output `X̂₁` into rates that sum to zero along the first dimension.

`f(X̂₁)` is masked by `1 .- Xt` to give the off-state entries; the negated sum
of those entries over dimension 1 is then placed where `Xt` is nonzero.
"""
function rate_constraint(Xt, X̂₁, f)
    offstate = f(X̂₁) .* (1 .- Xt)
    outflow = sum(offstate, dims = 1)
    return offstate .- outflow .* Xt
end
48
+
49
"""
    velo_step(P, Xₜ, delta_t, log_velocity, scale)

Sample a new discrete state after a step of size `delta_t`.

The model output `log_velocity` is mapped through `P.transform` and
`rate_constraint`, multiplied by `scale`, and added (times `delta_t`) to the
one-hot encoding of `Xₜ`. The result is converted to `eltype(delta_t)`, clamped
to be non-negative, and sampled from as a `CategoricalLikelihood`.
"""
function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, log_velocity, scale)
    current = tensor(onehot(Xₜ))
    rates = rate_constraint(current, log_velocity, P.transform) .* scale
    stepped = current .+ (delta_t .* rates)
    proposal = CategoricalLikelihood(eltype(delta_t).(stepped))
    # One rate per column is negative, so a large step can push an entry below zero.
    clamp!(tensor(proposal), 0, Inf)
    return rand(proposal)
end
44
56
45
- step (P:: DoobMatchingFlow , Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , veloX̂₁:: Flowfusion.Guide , s₁, s₂) = velo_step (Xₜ, s₂ .- s₁, veloX̂₁. H)
46
- step (P:: DoobMatchingFlow , Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , veloX̂₁, s₁, s₂) = velo_step (Xₜ, s₂ .- s₁, veloX̂₁)
57
# Step from s₁ to s₂ when the model output arrives wrapped in a `Flowfusion.Guide`:
# unwrap `.H` and undo the `onescale` factor evaluated at s₁.
function step(
    P::DoobMatchingFlow,
    Xₜ::DiscreteState{<:AbstractArray{<:Signed}},
    veloX̂₁::Flowfusion.Guide,
    s₁,
    s₂,
)
    rescale = expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H))
    return velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, rescale)
end

# Step from s₁ to s₂ when the model output is passed directly.
function step(
    P::DoobMatchingFlow,
    Xₜ::DiscreteState{<:AbstractArray{<:Signed}},
    veloX̂₁,
    s₁,
    s₂,
)
    rescale = expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))
    return velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, rescale)
end
59
+
60
"""
    cgm_dloss(P, Xt, X̂₁, doobX₁)

Loss summed over the first dimension, comparing the transformed model output
with the target `doobX₁`.

With `Qt = P.transform(X̂₁)`, each entry contributes `Qt - xlogy(doobX₁, Qt)`,
masked by `1 .- Xt`. Entries where `Xt` is one (the diagonals) are ignored,
because they are implied by the zero-sum constraint.
"""
function cgm_dloss(P, Xt, X̂₁, doobX₁)
    rates = P.transform(X̂₁)
    offstate = 1 .- Xt
    return sum(offstate .* (rates .- xlogy.(doobX₁, rates)), dims = 1)
end
47
64
48
- poisson_loss (mu, count, mask ) = sum (mask .* (mu .- xlogy .(count, mu))) / sum (mask )
65
# Loss for `DoobMatchingFlow`: `cgm_dloss` on the raw tensors of `Xt` and `X̂₁`
# against the guide `X₁.H`, reduced by `Flowfusion.scaledmaskedmean` with scale
# `c` and the mask returned by `Flowfusion.getlmask(X₁)`.
function floss(
    P::Flowfusion.fbu(DoobMatchingFlow),
    Xt::Flowfusion.msu(DiscreteState),
    X̂₁,
    X₁::Guide,
    c,
)
    elementwise = cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H)
    return Flowfusion.scaledmaskedmean(elementwise, c, Flowfusion.getlmask(X₁))
end
0 commit comments