Skip to content

Commit 3bac5ed

Browse files
authored
Merge pull request #6 from MurrellGroup/dummy_token
Adding type param to NoisyInterpolatingDiscreteFlow
2 parents ec4030f + 6d4961e commit 3bac5ed

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

src/processes.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,31 @@ end
4141

4242

4343
"""
44-
NoisyInterpolatingDiscreteFlow(κ₁, κ₂, dκ₁, dκ₂)
45-
NoisyInterpolatingDiscreteFlow(noise, K = 1) - Uses default cosine schedule, where `noise` is the maximum amplitude of the uniform noise component.
44+
NoisyInterpolatingDiscreteFlow(κ₁, κ₂, dκ₁, dκ₂, dummy_token)
45+
NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token = nothing) - Uses default cosine schedule, where `noise` is the maximum amplitude of the uniform noise component.
4646
NoisyInterpolatingDiscreteFlow() - Uses default cosine schedule and noise = 0.2.
4747
4848
A convex mixture of X0, uniform noise, and X1. Equation 10 in https://arxiv.org/pdf/2407.15595
4949
Compared to InterpolatingDiscreteFlow, it encourages the model to make multiple switches during inference.
5050
κ₁, κ₂ are the schedules for target token interpolation and uniform noise probability.
5151
dκ₁, dκ₂ are the derivatives of κ₁, κ₂.
5252
Defaults to using a cosine schedule. `K=2` will resolve the discrete states later than `K=1`.
53+
If `K > 1`, things might break unless your X0 is the `dummy_token` (also called the mask token), which should then be passed to `NoisyInterpolatingDiscreteFlow`.
5354
"""
54-
55-
NoisyInterpolatingDiscreteFlow(noise, K = 1) = NoisyInterpolatingDiscreteFlow(
56-
t -> oftype(t,(1 - cos((π/2)*t))^K), #K1
57-
t -> oftype(t,(noise * sin(π*t))), #K2
58-
t -> oftype(t,(K * (π/2) * sin((π/2) * t) * (1 - cos((π/2) * t))^(K - 1))), #dK1
59-
t -> oftype(t,(noise*π*cos(π*t))) #dK2
60-
)
61-
NoisyInterpolatingDiscreteFlow() = NoisyInterpolatingDiscreteFlow(0.2)
62-
55+
function NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T
56+
if (K > 1 && isnothing(dummy_token))
57+
@warn "NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to NoisyInterpolatingDiscreteFlow)."
58+
end
59+
return NoisyInterpolatingDiscreteFlow{T}(
60+
t -> oftype(t,(1 - cos((π/2)*t))^K), #K1
61+
t -> oftype(t,(noise * sin(π*t))), #K2
62+
t -> oftype(t,(K * (π/2) * sin((π/2) * t) * (1 - cos((π/2) * t))^(K - 1))), #dK1
63+
t -> oftype(t,(noise*π*cos(π*t))), #dK2
64+
dummy_token
65+
)
66+
end
67+
# Zero-argument convenience constructor: default cosine schedule, noise = 0.2, no dummy token.
# Delegates to the `noise` constructor, which infers `T = Nothing` from `dummy_token = nothing`.
# (`NoisyInterpolatingDiscreteFlow{Nothing}(0.2)` has no matching method: the parameterised
# type only has the default five-field constructor, so that call throws a MethodError.)
NoisyInterpolatingDiscreteFlow() = NoisyInterpolatingDiscreteFlow(0.2)
68+
# Two-positional-argument form: forwards `power` as the `K` keyword of the `noise` constructor.
function NoisyInterpolatingDiscreteFlow(noise, power)
    return NoisyInterpolatingDiscreteFlow(noise; K = power)
end
6369
function bridge(p::NoisyInterpolatingDiscreteFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t)
6470
D = size(x0.state)
6571
ts = expand(t, ndims(x0.state))
@@ -78,8 +84,7 @@ function bridge(p::NoisyInterpolatingDiscreteFlow, x0::DiscreteState{<:AbstractA
7884
end
7985
return Xt
8086
end
81-
82-
function step(P::NoisyInterpolatingDiscreteFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
87+
function step(P::NoisyInterpolatingDiscreteFlow{Nothing}, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
8388
T = eltype(s₁)
8489
Δt = s₂ .- s₁
8590
ohXₜ = onehot(Xₜ)
@@ -98,3 +103,29 @@ function step(P::NoisyInterpolatingDiscreteFlow, Xₜ::DiscreteState{<:AbstractA
98103
clamp!(tensor(newXₜ), 0, Inf)
99104
return rand(newXₜ)
100105
end
106+
function step(P::NoisyInterpolatingDiscreteFlow{<:Integer}, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
107+
T = eltype(s₁)
108+
Δt = s₂ .- s₁
109+
ohXₜ = onehot(Xₜ)
110+
pu = T(1/Xₜ.K)
111+
eps = T(1e-10)
112+
κ1 = P.κ₁.(s₁)
113+
κ2 = P.κ₂.(s₁)
114+
κ3 = (1 .- (κ1 .+ κ2)) # κ₃(t)=1-κ₁(t)-κ₂(t)
115+
dκ1 = P.dκ₁.(s₁)
116+
dκ2 = P.dκ₂.(s₁)
117+
dκ3 = .- (dκ1 .+ dκ2) # Because dκ₃ = - (dκ₁+dκ₂)
118+
#Theorem 3 applied to equation 10 in https://arxiv.org/pdf/2407.15595
119+
r1 = dκ1 ./ (eps .+ κ1)
120+
r2 = dκ2 ./ (eps .+ κ2)
121+
r3 = dκ3 ./ (eps .+ κ3)
122+
bt = min.(r1,r2, r3) #b_t = min_j dκ_j/κ_j
123+
a1 = dκ1 .- κ1 .* bt # component 1 (denoiser)
124+
a2 = dκ2 .- κ2 .* bt # component 2 (uniform)
125+
a3 = dκ3 .- κ3 .* bt # component 3 (dummy/mask)
126+
velo = a1 .* tensor(X̂₁) .+ a2 .* pu .+ bt .* tensor(ohXₜ)
127+
selectdim(velo,1,P.mask_token) .+= a3 #Adding the mask token component to the correct tensor slice
128+
newXₜ = CategoricalLikelihood(eltype(s₁).(tensor(ohXₜ) .+ (Δt .* velo)))
129+
clamp!(tensor(newXₜ), 0, Inf)
130+
return rand(newXₜ)
131+
end

src/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ struct InterpolatingDiscreteFlow <: ConvexInterpolatingDiscreteFlow
5757
κ̇::Function
5858
end
5959

60-
struct NoisyInterpolatingDiscreteFlow <: ConvexInterpolatingDiscreteFlow
60+
struct NoisyInterpolatingDiscreteFlow{T} <: ConvexInterpolatingDiscreteFlow
6161
κ₁::Function # schedule for target token interpolation
6262
κ₂::Function # schedule for uniform noise probability
6363
dκ₁::Function # derivative of κ₁
6464
dκ₂::Function # derivative of κ₂
65+
mask_token::T # the token that is used for the X0 state
6566
end

0 commit comments

Comments
 (0)