|
| 1 | +#ToDo: Incorporate FProcesses, with their schedules. The bridge behavior should already be correct, |
| 2 | +#and the fallback doob should be correct if delta is passed through the schedule. |
| 3 | +#But for the closed form we'll need to mod the velocities per the gradient, etc, and the same when stepping. |
| 4 | + |
| 5 | +struct DoobMatchingFlow{Proc} <: Process |
| 6 | + P::Proc |
| 7 | +end |
| 8 | +export DoobMatchingFlow |
| 9 | + |
| 10 | +Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t) = bridge(p.P, x0, x1, t) |
| 11 | +#Finite diff fallback for when we don't have a closed form for the forward positive velocities: |
| 12 | +function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5)) |
| 13 | + return (tensor(forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)) .- tensor(onehot(Xt))) ./ delta; |
| 14 | +end |
| 15 | + |
| 16 | +doob_guide(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState) = fallback_doob(P, t, Xt, X1) |
| 17 | + |
| 18 | +function closed_form_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState) |
| 19 | + tenXt = tensor(onehot(Xt)) |
| 20 | + bk = tensor(backward(X1, P, 1 .- t)) |
| 21 | + fv = forward_positive_velocities(onehot(Xt), P) |
| 22 | + positive_doob = (fv .* bk) ./ sum(bk .* tenXt, dims = 1) |
| 23 | + return positive_doob .- tenXt .* sum(positive_doob, dims = 1) |
| 24 | +end |
| 25 | + |
| 26 | +forward_positive_velocities(Xt::DiscreteState, P::PiQ)= (P.r .* (P.π ./ sum(P.π))) .* (1 .- tensor(onehot(Xt))) |
| 27 | +doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1) |
| 28 | + |
| 29 | +forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where T = (P.μ .* T((1 ./ (Xt.K-1)))) .* (1 .- tensor(onehot(Xt))) |
| 30 | +doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1) |
| 31 | + |
| 32 | +forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where T = (P.μ * T(1/(Xt.K*(1-1/Xt.K)))) .* (1 .- tensor(onehot(Xt))) |
| 33 | +doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1) |
| 34 | + |
| 35 | +Guide(P::DoobMatchingFlow, t, Xt::DiscreteState, X1::DiscreteState) = Flowfusion.Guide(doob_guide(P.P, t, Xt, X1)) |
| 36 | +Guide(P::DoobMatchingFlow, t, mXt::Union{MaskedState{<:DiscreteState}, DiscreteState}, mX1::MaskedState{<:DiscreteState}) = Guide(doob_guide(P.P, t, mXt, mX1), mX1.cmask, mX1.lmask) |
| 37 | + |
| 38 | +function velo_step(Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, velocity) |
| 39 | + ohXₜ = onehot(Xₜ) |
| 40 | + newXₜ = CategoricalLikelihood(eltype(delta_t).(tensor(ohXₜ) .+ (delta_t .* velocity))) |
| 41 | + clamp!(tensor(newXₜ), 0, Inf) #Because one velo will be < 0 and a large step might push Xₜ < 0 |
| 42 | + return rand(newXₜ) |
| 43 | +end |
| 44 | + |
| 45 | +step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(Xₜ, s₂ .- s₁, veloX̂₁.H) |
| 46 | +step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂) = velo_step(Xₜ, s₂ .- s₁, veloX̂₁) |
| 47 | + |
| 48 | +poisson_loss(mu, count, mask) = sum(mask .* (mu .- xlogy.(count, mu))) / sum(mask) |
0 commit comments