Skip to content

Commit c032366

Browse files
authored
Adding type param to NoisyInterpolatingDiscreteFlow
1 parent ec4030f commit c032366

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

src/processes.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,30 @@ end
4141

4242

4343
"""
44-
NoisyInterpolatingDiscreteFlow(κ₁, κ₂, dκ₁, dκ₂)
45-
NoisyInterpolatingDiscreteFlow(noise, K = 1) - Uses default cosine schedule, where `noise` is the maximum amplitude of the uniform noise component.
44+
NoisyInterpolatingDiscreteFlow(κ₁, κ₂, dκ₁, dκ₂, dummy_token)
45+
NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token = nothing) - Uses default cosine schedule, where `noise` is the maximum amplitude of the uniform noise component.
4646
NoisyInterpolatingDiscreteFlow() - Uses default cosine schedule and noise = 0.2.
4747
4848
A convex mixture of X0, uniform noise, and X1. Equation 10 in https://arxiv.org/pdf/2407.15595
4949
Compared to InterpolatingDiscreteFlow, it encourages the model to make multiple switches during inference.
5050
κ₁, κ₂ are the schedules for target token interpolation and uniform noise probability.
5151
dκ₁, dκ₂ are the derivatives of κ₁, κ₂.
5252
Defaults to using a cosine schedule. `K=2` will resolve the discrete states later than `K=1`.
53+
When `K > 1`, sampling may break unless X0 is the `dummy_token` (also called the mask token); pass that token to `NoisyInterpolatingDiscreteFlow` via the `dummy_token` keyword.
5354
"""
54-
55-
NoisyInterpolatingDiscreteFlow(noise, K = 1) = NoisyInterpolatingDiscreteFlow(
56-
t -> oftype(t,(1 - cos((π/2)*t))^K), #K1
57-
t -> oftype(t,(noise * sin(π*t))), #K2
58-
t -> oftype(t,(K * (π/2) * sin((π/2) * t) * (1 - cos((π/2) * t))^(K - 1))), #dK1
59-
t -> oftype(t,(noise*π*cos(π*t))) #dK2
60-
)
61-
NoisyInterpolatingDiscreteFlow() = NoisyInterpolatingDiscreteFlow(0.2)
62-
55+
function NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T
    # Build a flow with the default cosine schedules.
    #   noise       - maximum amplitude of the uniform-noise schedule κ₂
    #   K           - exponent applied to the target schedule κ₁
    #   dummy_token - mask token stored on the flow; its type becomes the type parameter T
    # With K > 1 the caller is expected to supply the dummy/mask token, so warn if it is absent.
    if K > 1 && isnothing(dummy_token)
        @warn "NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to NoisyInterpolatingDiscreteFlow)."
    end
    # `oftype(t, …)` converts every schedule value back to the numeric type of the time `t`.
    κ₁ = t -> oftype(t, (1 - cos((π/2) * t))^K)                                        # K1
    κ₂ = t -> oftype(t, noise * sin(π * t))                                            # K2
    # Analytic derivatives of the two schedules above.
    dκ₁ = t -> oftype(t, K * (π/2) * sin((π/2) * t) * (1 - cos((π/2) * t))^(K - 1))     # dK1
    dκ₂ = t -> oftype(t, noise * π * cos(π * t))                                       # dK2
    return NoisyInterpolatingDiscreteFlow{T}(κ₁, κ₂, dκ₁, dκ₂, dummy_token)
end
67+
# Zero-argument convenience constructor: default cosine schedule with noise = 0.2.
# Delegates to the `noise`/keyword constructor, whose `dummy_token = nothing` default
# produces a `NoisyInterpolatingDiscreteFlow{Nothing}`. Calling
# `NoisyInterpolatingDiscreteFlow{Nothing}(0.2)` directly has no matching method, because
# the parameterised type only has the five-field default constructor.
NoisyInterpolatingDiscreteFlow() = NoisyInterpolatingDiscreteFlow(0.2)
6368
function bridge(p::NoisyInterpolatingDiscreteFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t)
6469
D = size(x0.state)
6570
ts = expand(t, ndims(x0.state))
@@ -78,8 +83,7 @@ function bridge(p::NoisyInterpolatingDiscreteFlow, x0::DiscreteState{<:AbstractA
7883
end
7984
return Xt
8085
end
81-
82-
function step(P::NoisyInterpolatingDiscreteFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
86+
function step(P::NoisyInterpolatingDiscreteFlow{Nothing}, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
8387
T = eltype(s₁)
8488
Δt = s₂ .- s₁
8589
ohXₜ = onehot(Xₜ)
@@ -98,3 +102,29 @@ function step(P::NoisyInterpolatingDiscreteFlow, Xₜ::DiscreteState{<:AbstractA
98102
clamp!(tensor(newXₜ), 0, Inf)
99103
return rand(newXₜ)
100104
end
105+
# One sampling step from time s₁ to s₂ for a flow whose mask token is an integer.
# A velocity is assembled from three components (denoiser prediction X̂₁, uniform noise,
# mask token), added to the one-hot current state scaled by Δt = s₂ - s₁, negative
# entries are clamped to zero, and the next state is sampled from the result.
function step(P::NoisyInterpolatingDiscreteFlow{<:Integer}, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, X̂₁, s₁, s₂)
    T = eltype(s₁)
    dt = s₂ .- s₁
    onehot_now = onehot(Xₜ)
    uniform_p = T(1/Xₜ.K)   # presumably 1 / (number of discrete states) — confirm against DiscreteState
    ϵ = T(1e-10)            # added to every κ denominator below

    # Schedules and their derivatives, all evaluated at the start time s₁.
    κ_target = P.κ₁.(s₁)
    κ_noise = P.κ₂.(s₁)
    κ_mask = (1 .- (κ_target .+ κ_noise))    # κ₃(t) = 1 - κ₁(t) - κ₂(t)
    dκ_target = P.dκ₁.(s₁)
    dκ_noise = P.dκ₂.(s₁)
    dκ_mask = .- (dκ_target .+ dκ_noise)     # dκ₃ = -(dκ₁ + dκ₂)

    # Theorem 3 applied to equation 10 in https://arxiv.org/pdf/2407.15595
    rate(dκ, κ) = dκ ./ (ϵ .+ κ)
    rate_target = rate(dκ_target, κ_target)
    rate_noise = rate(dκ_noise, κ_noise)
    rate_mask = rate(dκ_mask, κ_mask)
    b = min.(rate_target, rate_noise, rate_mask)    # b_t = min_j dκ_j / κ_j

    coef_target = dκ_target .- κ_target .* b    # component 1 (denoiser)
    coef_noise = dκ_noise .- κ_noise .* b       # component 2 (uniform)
    coef_mask = dκ_mask .- κ_mask .* b          # component 3 (dummy/mask)

    velocity = coef_target .* tensor(X̂₁) .+ coef_noise .* uniform_p .+ b .* tensor(onehot_now)
    # The mask-token component only contributes to the slice selected by P.mask_token along dim 1.
    selectdim(velocity, 1, P.mask_token) .+= coef_mask

    next_state = CategoricalLikelihood(T.(tensor(onehot_now) .+ (dt .* velocity)))
    clamp!(tensor(next_state), 0, Inf)
    return rand(next_state)
end

src/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ struct InterpolatingDiscreteFlow <: ConvexInterpolatingDiscreteFlow
5757
κ̇::Function
5858
end
5959

60-
struct NoisyInterpolatingDiscreteFlow <: ConvexInterpolatingDiscreteFlow
60+
# Discrete flow holding two schedules, their derivatives, and a mask token.
# The type parameter `T` is the type of `mask_token`.
struct NoisyInterpolatingDiscreteFlow{T} <: ConvexInterpolatingDiscreteFlow
    # Schedule for target token interpolation.
    κ₁::Function
    # Schedule for uniform noise probability.
    κ₂::Function
    # Derivative of κ₁.
    dκ₁::Function
    # Derivative of κ₂.
    dκ₂::Function
    # Token that is used for the X0 state.
    mask_token::T
end

0 commit comments

Comments
 (0)