Skip to content

Commit c527044

Browse files
authored
Add DoobMatchingFlow
1 parent b4525f7 commit c527044

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

src/doob.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ToDo: Incorporate FProcesses, with their schedules. The bridge behavior should already be correct,
2+
#and the fallback doob should be correct if delta is passed through the schedule.
3+
#But for the closed form we'll need to mod the velocities per the gradient, etc, and the same when stepping.
4+
5+
struct DoobMatchingFlow{Proc} <: Process
6+
P::Proc
7+
end
8+
export DoobMatchingFlow
9+
10+
Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t) = bridge(p.P, x0, x1, t)
11+
#Finite diff fallback for when we don't have a closed form for the forward positive velocities:
12+
function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5))
13+
return (tensor(forward(Xt, P, delta) backward(X1, P, (1 .- t) .- delta)) .- tensor(onehot(Xt))) ./ delta;
14+
end
15+
16+
doob_guide(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState) = fallback_doob(P, t, Xt, X1)
17+
18+
function closed_form_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState)
19+
tenXt = tensor(onehot(Xt))
20+
bk = tensor(backward(X1, P, 1 .- t))
21+
fv = forward_positive_velocities(onehot(Xt), P)
22+
positive_doob = (fv .* bk) ./ sum(bk .* tenXt, dims = 1)
23+
return positive_doob .- tenXt .* sum(positive_doob, dims = 1)
24+
end
25+
26+
forward_positive_velocities(Xt::DiscreteState, P::PiQ)= (P.r .* (P.π ./ sum(P.π))) .* (1 .- tensor(onehot(Xt)))
27+
doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)
28+
29+
forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where T = (P.μ .* T((1 ./ (Xt.K-1)))) .* (1 .- tensor(onehot(Xt)))
30+
doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)
31+
32+
forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where T = (P.μ * T(1/(Xt.K*(1-1/Xt.K)))) .* (1 .- tensor(onehot(Xt)))
33+
doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)
34+
35+
Guide(P::DoobMatchingFlow, t, Xt::DiscreteState, X1::DiscreteState) = Flowfusion.Guide(doob_guide(P.P, t, Xt, X1))
36+
Guide(P::DoobMatchingFlow, t, mXt::Union{MaskedState{<:DiscreteState}, DiscreteState}, mX1::MaskedState{<:DiscreteState}) = Guide(doob_guide(P.P, t, mXt, mX1), mX1.cmask, mX1.lmask)
37+
38+
function velo_step(Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, velocity)
39+
ohXₜ = onehot(Xₜ)
40+
newXₜ = CategoricalLikelihood(eltype(delta_t).(tensor(ohXₜ) .+ (delta_t .* velocity)))
41+
clamp!(tensor(newXₜ), 0, Inf) #Because one velo will be < 0 and a large step might push Xₜ < 0
42+
return rand(newXₜ)
43+
end
44+
45+
step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(Xₜ, s₂ .- s₁, veloX̂₁.H)
46+
step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂) = velo_step(Xₜ, s₂ .- s₁, veloX̂₁)
47+
48+
poisson_loss(mu, count, mask) = sum(mask .* (mu .- xlogy.(count, mu))) / sum(mask)

0 commit comments

Comments
 (0)