
Commit 72546b4
1 parent 909e541 commit 72546b4

6 files changed: +300 -51 lines changed

examples/discrete.jl

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../../ForwardBackward/")
Pkg.develop(path="../")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots

struct DModel{A}
    layers::A
end

Flux.@layer DModel

function DModel(; embeddim = 64, l = 2, K = 32, layers = 5)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 2.0f0), Dense(embeddim => embeddim, leakyrelu))
    embed_char = Dense(K => embeddim, bias = false)
    mix = Dense(l*embeddim => embeddim, leakyrelu)
    ffs = [Dense(embeddim => embeddim, leakyrelu) for _ in 1:layers]
    decode = Dense(embeddim => l*K)
    layers = (; embed_time, embed_char, mix, ffs, decode)
    DModel(layers)
end

function (f::DModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    len = size(tXt)[end]
    tv = zero(similar(Float32.(tXt), 1, len)) .+ expand(t, 2)
    x = l.mix(reshape(l.embed_char(tXt), :, len)) .+ l.embed_time(tv)
    for ff in l.ffs
        x = x .+ ff(x)
    end
    reshape(l.decode(x), :, 2, len)
end

T = Float32
n_samples = 1000

sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples)
sampleX0(n_samples) = rand(25:32, 2, n_samples)
#sampleX0(n_samples) = [33 for _ in zeros(2, n_samples)] #Required if you want to use a UniformUnmasking process

P = NoisyInterpolatingDiscreteFlow(0.1)
#P = InterpolatingDiscreteFlow()
#P = UniformUnmasking()

model = DModel(embeddim = 128, l = 2, K = 33, layers = 2)

eta = 0.005
opt_state = Flux.setup(Adam(eta), model)

iters = 4000
for i in 1:iters
    #Set up a batch of training pairs, and t
    X1 = DiscreteState(33, sampleX1(n_samples))
    X0 = DiscreteState(33, sampleX0(n_samples))
    t = rand(T, 1, n_samples)
    #Construct the bridge:
    Xt = stochastic(Float32, bridge(P, X0, X1, t))
    #Gradient
    l, g = Flux.withgradient(model) do m
        floss(P, m(t, Xt), onehot(X1), t) #CE loss - Scaling with t doesn't seem critical for this one
    end
    #Update
    Flux.update!(opt_state, model, g[1])
    if i % 10 == 0
        if i > iters - 1000
            eta *= 0.975
            Optimisers.adjust!(opt_state, eta)
        end
        println("i: $i; Loss: $l; eta: $eta")
    end
end

n_inference_samples = 10000
X0 = DiscreteState(33, sampleX0(n_inference_samples))
paths = Tracker()
samp = gen(P, X0, (t, Xt) -> softmax(model(t, onehot(Xt))), 0f0:0.001f0:1f0, tracker = paths) #Note the softmax here

pl = scatter(X0.state[1,:], X0.state[2,:], msw = 0, color = "blue", alpha = 0.4, label = "Initial", size = (400,400), legend = :topleft, xlim = (1,33), ylim = (1,33))
scatter!(samp.state[1,:], samp.state[2,:], msw = 0, color = "green", alpha = 0.04, label = :none)
scatter!([-10], [-10], msw = 0, color = "green", alpha = 0.3, label = "Sampled")
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
for i in 1:200:n_inference_samples
    plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = :none, alpha = 0.15)
end
plot!([-10], [-10], color = "red", label = "Trajectory", alpha = 0.4)
pl
savefig("discrete_$P.svg")

examples/torus.jl

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@ Pkg.activate(".")
 using Revise
 Pkg.develop(path="../../ForwardBackward/")
 Pkg.develop(path="../")
-using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
+using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots, Manifolds

 #Set up a Flux model: ξhat = model(t,Xt)
 struct TModel{A}
@@ -39,8 +39,8 @@ sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))[
 n_samples = 500

 M = Torus(2)
-#P = ManifoldProcess(0.2f0)
-P = Deterministic()
+P = ManifoldProcess(0.2f0)
+#P = Deterministic()

 eta = 0.01
 opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.00001), model)
@@ -76,7 +76,7 @@ n_inference_samples = 2000
 X0 = ManifoldState(M, eachcol(sampleX0(n_inference_samples)))
 paths = Tracker()
 #We wrap the model, because it was predicting tangent coordinates, not the actual state:
-X1pred = (t,Xt) -> apply_tangent_coordinates(Xt, model(t,tensor(Xt)))
+X1pred = (t,Xt) -> BackwardGuide(model(t,tensor(Xt)))
 samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)

 #Plot the torus, with samples, and trajectories:

src/Flowfusion.jl

Lines changed: 6 additions & 1 deletion
@@ -4,8 +4,12 @@ using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib

 include("bridge.jl")
 include("loss.jl")
+include("processes.jl")

-export
+export
+    #Processes not in ForwardBackward.jl
+    InterpolatingDiscreteFlow,
+    NoisyInterpolatingDiscreteFlow,
     MaskedState,
     bridge,
     scalefloss,
@@ -15,6 +19,7 @@ export
     onehot,
     FProcess,
     tangent_coordinates,
+    BackwardGuide,
     apply_tangent_coordinates,
     floss,
     tcloss

src/bridge.jl

Lines changed: 67 additions & 37 deletions
@@ -17,20 +17,37 @@ process(P::Process) = P
 tscale(P::Process, t) = t
 tscale(P::FProcess, t) = P.F.(t)

+#=#####################
+Conditioning mask behavior:
+Typically, during training, the conditioning mask is constructed on the training observation, X1.
+During inference, the conditioning mask (and conditioned-upon state) has to be present on X1.
+This dictates the behavior of the masking:
+- When bridge() is called, the mask, and the state where mask=1, are inherited from X1.
+- When gen is called, the state and mask will be propagated from X0 through all of the Xts.
+=#####################
 struct MaskedState{A,B,C}
     S::A #State
     cmask::B #Conditioning mask. 1 = Xt=X1
     lmask::C #Loss mask. 1 = included in loss
 end

+#For when we want to predict the transitions instead of X1hat
+struct BackwardGuide{A}
+    H::A
+end
+ForwardBackward.:⊙(a::CategoricalLikelihood, b::BackwardGuide) = ⊙(a, copytensor!(copy(a), b.H))
+
+#⊙ itself doesn't force the masks - it just propagates them. The forcing happens elsewhere.
+ForwardBackward.:⊙(a::MaskedState, b::MaskedState; kwargs...) = MaskedState(⊙(a.S, b.S; kwargs...), a.cmask .* b.cmask, a.lmask .* b.lmask)
+
 Adapt.adapt_structure(to, S::ForwardBackward.DiscreteState) = ForwardBackward.DiscreteState(S.K, Adapt.adapt(to, S.state))
 Adapt.adapt_structure(to, S::ForwardBackward.ContinuousState) = ForwardBackward.ContinuousState(Adapt.adapt(to, S.state))
 Adapt.adapt_structure(to, S::ForwardBackward.CategoricalLikelihood) = ForwardBackward.CategoricalLikelihood(Adapt.adapt(to, S.dist), Adapt.adapt(to, S.log_norm_const))
 Adapt.adapt_structure(to, MS::MaskedState{<:State}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
 Adapt.adapt_structure(to, MS::MaskedState{<:CategoricalLikelihood}) = MaskedState(Adapt.adapt(to, MS.S), Adapt.adapt(to, MS.cmask), Adapt.adapt(to, MS.lmask))
 Adapt.adapt_structure(to, S::ForwardBackward.ManifoldState) = ForwardBackward.ManifoldState(S.M, Adapt.adapt(to, S.state))

-UState = Union{State,MaskedState}
+UState = Union{State,MaskedState, BackwardGuide}

 ForwardBackward.tensor(X::MaskedState) = tensor(X.S)

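A minimal training-time sketch of the masking described above (illustrative only, not part of this commit; the mask construction and the lmask = .!cmask choice are assumptions), with X0, X1, P, and t as in examples/discrete.jl:

    known = rand(Bool, size(X1.state))     #1 = positions clamped to the observation X1
    X1m = MaskedState(X1, known, .!known)  #field order: S, cmask, lmask
    Xt = bridge(P, X0, X1m, t)             #Xt inherits X1's state (and the masks) where cmask = 1
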
@@ -63,6 +80,40 @@ cmask!(Xt, X1::MaskedState) = cmask!(Xt.S.state, X1.S.state, X1.cmask)
 cmask!(Xt, X1::MaskedState{<:CategoricalLikelihood}) = error("Cannot condition on a CategoricalLikelihood")
 cmask!(x̂₁::Tuple, x₀::Tuple) = map(cmask!, x̂₁, x₀)

+
+#copytensor! and resolveprediction are used to handle the state translation that happens in gen(...).
+#We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc).
+#This means we need to automagically create a State (typical for the continuous case) or Likelihood (typical for the discrete case) from the tensor.
+#But the user may return a State in the Discrete case (for massive state spaces with sub-linear sampling), and a Likelihood in the Continuous case (for variance matching models).
+#This also needs to handle MaskedStates (needs testing).
+#We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
+#Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass through.
+#When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
+
+function copytensor!(dest, src)
+    tensor(dest) .= tensor(src)
+    return dest
+end
+#copytensor!(dest::Tuple, src::Tuple) = map(copytensor!, dest, src)
+
+#resolveprediction exists to stop bridge from needing multiple definitions.
+#Tuple broadcast:
+resolveprediction(dest::Tuple, src::Tuple) = map(resolveprediction, dest, src)
+#Default if X̂₁ is a plain tensor:
+resolveprediction(X̂₁, Xₜ::DiscreteState) = copytensor!(stochastic(Xₜ), X̂₁) #Returns a Likelihood
+resolveprediction(X̂₁, Xₜ::State) = copytensor!(copy(Xₜ), X̂₁) #Returns a State - Handles Continuous and Manifold cases
+#Passthrough if the user returns a State or Likelihood
+resolveprediction(X̂₁::State, Xₜ) = X̂₁
+resolveprediction(X̂₁::State, Xₜ::State) = X̂₁
+resolveprediction(X̂₁::StateLikelihood, Xₜ) = X̂₁
+
+#Passthrough if the model returns a BackwardGuide, because we have a custom bridge for that.
+resolveprediction(G::BackwardGuide, Xₜ::DiscreteState) = G
+resolveprediction(G::BackwardGuide, Xₜ::ManifoldState) = apply_tangent_coordinates(Xₜ, G.H)
+#We could also add a case for where the guide is a tangent coordinate and X₀ is a ManifoldState.
+
+
 """
     bridge(P, X0, X1, t)
     bridge(P, X0, X1, t0, t)
@@ -82,59 +133,38 @@ end
 bridge(P, X0, X1, t) = bridge(P, X0, X1, eltype(t)(0.0), t)
 bridge(P::Tuple{Vararg{UProcess}}, X0::Tuple{Vararg{UState}}, X1::Tuple, t0, t) = bridge.(P, X0, X1, (t0,), (t, ))

+#Step is like bridge (and falls back to it where possible). But sometimes we only have enough to take an Euler step (which is ok when `s₂-s₁` is small).
+step(P, Xₜ, hat, s₁, s₂) = bridge(P, Xₜ, hat, s₁, s₂)
+step(P::Tuple{Vararg{UProcess}}, Xₜ::Tuple{Vararg{UState}}, hat::Tuple, s₁, s₂) = step.(P, Xₜ, hat, (s₁,), (s₂, ))
+#step(P::DiscreteProcess, Xₜ::DiscreteState, hat::BackwardGuide, s₁, s₂) = rand(forward(Xₜ, P, s₂ .- s₁) ⊙ hat) #<- Doesn't work


-#copytensor! and predictresolve are used handle the state translation that happens in gen(...).
-#We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc).
-#This means we need to automagically create a State (typical for the continuous case) or Likelihood (typical for the discrete case) from the tensor.
-#But the user may return a State in the Discrete case (for massive state spaces with sub-linear sampling), and a Likelihood in the Continuous case (for variance matching models)
-#This also needs to handle MaskedStates (needs testing).
-#We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
-#Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass through.
-#When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
-
-function copytensor!(dest, src)
-    tensor(dest) .= tensor(src)
-    return dest
-end
-#copytensor!(dest::Tuple, src::Tuple) = map(copytensor!, dest, src)
-
-#Tuple broadcast:
-resolveprediction(dest::Tuple, src::Tuple) = map(resolveprediction, dest, src)
-#Default if X̂₁ is a plain tensor:
-resolveprediction(X̂₁, X₀::DiscreteState) = copytensor!(stochastic(X₀), X̂₁) #Returns a Likelihood
-resolveprediction(X̂₁, X₀::State) = copytensor!(copy(X₀), X̂₁) #Returns a State - Handles Continuous and Manifold cases
-#Passthrough if the user returns a State or Likelihood
-resolveprediction(X̂₁::State, X₀) = X̂₁
-resolveprediction(X̂₁::State, X₀::State) = X̂₁
-resolveprediction(X̂₁::StateLikelihood, X₀) = X̂₁
 #####Add MaskedState case(s)######

 ##################################

-
-
 """
-    gen(P, X0, X̂₁predictor, steps; tracker=Returns(nothing), midpoint = false)
+    gen(P, X0, model, steps; tracker=Returns(nothing), midpoint = false)

 Constructs a sequence of (stochastic) bridges between `X0` and the predicted `X̂₁` under the process `P`.
-`P`, `X0`, can also be tuples where the Nth element of `P` will be used for the Nth elements of `X0` and `X̂₁predictor`.
-X̂₁predictor is a function that takes `t` (scalar) and `Xₜ` (optionally a tuple) and returns `X̂₁` (a `UState`, a flat tensor with the right shape, or a tuple of either).
-If `X0` is a `MaskedState` (or has a ), then anything `X̂₁` will be conditioned on `X0` where the conditioning mask `X0.cmask` is 1.
+`P` and `X0` can also be tuples, where the Nth element of `P` will be used for the Nth elements of `X0` and `model`.
+`model` is a function that takes `t` (scalar) and `Xₜ` (optionally a tuple) and returns `hat` (a `UState`, a flat tensor with the right shape, or a tuple of either if you're combining processes).
+If `X0` is a `MaskedState`, then anything in `X̂₁` will be conditioned on `X0` where the conditioning mask `X0.cmask` is 1.
 """
-function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, X̂₁predictor, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false)
+function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false)
     Xₜ = copy.(X₀)
     for (s₁, s₂) in zip(steps, steps[begin+1:end])
         t = midpoint ? (s₁ + s₂) / 2 : t = s₁
-        X̂₁ = resolveprediction(X̂₁predictor(t, Xₜ), X₀)
-        cmask!(X̂₁, X₀)
-        Xₜ = bridge(P, Xₜ, X̂₁, s₁, s₂)
-        tracker(t, Xₜ, X̂₁)
+        hat = resolveprediction(model(t, Xₜ), Xₜ)
+        Xₜ = step(P, Xₜ, hat, s₁, s₂)
+        cmask!(Xₜ, X₀)
+        tracker(t, Xₜ, hat)
     end
     return Xₜ
 end

-gen(P, X₀, X̂₁predictor, args...; kwargs...) = gen((P,), (X₀,), (t, Xₜ) -> (X̂₁predictor(t[1], Xₜ[1]),), args...; kwargs...)[1]
+
+gen(P, X₀, model, args...; kwargs...) = gen((P,), (X₀,), (t, Xₜ) -> (model(t[1], Xₜ[1]),), args...; kwargs...)[1]

 struct Tracker <: Function
     t::Vector
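
A usage sketch for the updated `gen` signature, taken from examples/discrete.jl and examples/torus.jl in this commit (the model may return a plain tensor, a State/Likelihood, or a BackwardGuide):

    paths = Tracker()
    #Discrete case: the model returns a plain probability tensor (note the softmax), which gen resolves against Xt.
    samp = gen(P, X0, (t, Xt) -> softmax(model(t, onehot(Xt))), 0f0:0.001f0:1f0, tracker = paths)
    #Manifold case: the model predicts tangent coordinates, so its output is wrapped in a BackwardGuide.
    X1pred = (t, Xt) -> BackwardGuide(model(t, tensor(Xt)))
    samp = gen(P, X0, X1pred, 0f0:0.002f0:1f0, tracker = paths)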

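For a discrete Xₜ, the default plain-tensor rule above amounts to copying the model's probabilities into a CategoricalLikelihood. An illustrative sketch of that conversion (resolveprediction is internal to Flowfusion and not exported; names follow examples/discrete.jl):

    probs = softmax(model(t, onehot(Xt)))          #plain probability tensor from the model
    X̂₁ = Flowfusion.resolveprediction(probs, Xt)   #i.e. copytensor!(stochastic(Xt), probs), a CategoricalLikelihood
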
src/loss.jl

Lines changed: 31 additions & 9 deletions
@@ -20,16 +20,17 @@ ForwardBackward.stochastic(T::Type, o::DiscreteState{<:OneHotArray}) = Categoric
 getlmask(X1::UState) = X1.lmask
 getlmask(X1::State) = nothing

-rotangle(rots::AbstractArray{T,3}) where T = acos.(clamp.((rots[1,1,:] .+ rots[2,2,:] .+ rots[3,3,:] .- 1) ./ 2, T(-0.99), T(0.99)))
-rotangle(rots::AbstractArray) = reshape(rotangle(reshape(rots, 3, 3, :)), 1, size(rots)[3:end]...)
-torangle(x, y) = mod.(y .- x .+ π, 2π) .- π
-
-
+#This is badness that doesn't work:
+#rotangle(rots::AbstractArray{T,3}) where T = acos.(clamp.((rots[1,1,:] .+ rots[2,2,:] .+ rots[3,3,:] .- 1) ./ 2, T(-0.99), T(0.99)))
+#rotangle(rots::AbstractArray) = reshape(rotangle(reshape(rots, 3, 3, :)), 1, size(rots)[3:end]...)
+#torangle(x, y) = mod.(y .- x .+ π, 2π) .- π
+#msra(X̂₁, X₁) = rotangle(batched_mul(batched_transpose(tensor(X̂₁)), tensor(X₁))).^2 #Mean Squared Angle
+#msta(X̂₁, X₁) = sum(torangle(tensor(X̂₁), tensor(X₁)), dims=1).^2 #Mean Squared Toroidal Angle

 mse(X̂₁, X₁) = abs2.(tensor(X̂₁) .- tensor(X₁)) #Mean Squared Error
 lce(X̂₁, X₁) = -sum(tensor(X₁) .* logsoftmax(tensor(X̂₁)), dims=1) #Logit Cross Entropy
-msra(X̂₁, X₁) = rotangle(batched_mul(batched_transpose(tensor(X̂₁)), tensor(X₁))).^2 #Mean Squared Angle
-msta(X̂₁, X₁) = sum(torangle(tensor(X̂₁), tensor(X₁)), dims=1).^2 #Mean Squared Toroidal Angle
+kl(P,Q) = sum(softmax(tensor(P)) .* (logsoftmax(tensor(P)) .- log.(tensor(Q))), dims=1) #Kullback-Leibler Divergence
+rkl(P,Q) = sum(tensor(Q) .* (log.(tensor(Q)) .- logsoftmax(tensor(P))), dims=1) #Reverse Kullback-Leibler Divergence

 function scaledmaskedmean(l::AbstractArray{T}, c::Union{AbstractArray, Real}, m::Union{AbstractArray, Real}) where T
     expanded_m = expand(m, ndims(l))
@@ -58,8 +59,8 @@ floss(P::fbu(BrownianMotion), X̂₁, X₁::msu(ContinuousState),
 floss(P::fbu(OrnsteinUhlenbeck), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
 floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
 #For a discrete process, X̂₁ will be a distribution, and X₁ will have to be a onehot before going onto the gpu.
-floss(P::fbu(ForwardBackward.DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
-floss(P::fbu(ForwardBackward.DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
+floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
+floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
 floss(P::fbu(ManifoldProcess{Rotations(3)}), X̂₁, X₁::msu(ManifoldState{Rotations(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
 floss(P::fbu(ManifoldProcess{SpecialOrthogonal(3)}), X̂₁, X₁::msu(ManifoldState{SpecialOrthogonal(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
 floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = scaledmaskedmean(msta(X̂₁, X₁), c, getlmask(X₁))
@@ -70,6 +71,7 @@ floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = s
 Where `ξhat` is the predicted tangent coordinates, and `ξ` is the true tangent coordinates.
 """
 tcloss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ, c, mask = nothing) = scaledmaskedmean(mse(ξhat, ξ), c, mask)
+tcloss(P::fbu(DiscreteProcess), ξhat, ξ, c, mask = nothing) = scaledmaskedmean(rkl(ξhat, ξ), c, mask)

 #=If we want the model to directly predict the tangent coordinates, we use:
 - tangent_coordinates outside the gradient call to get the thing the model will predict
@@ -109,6 +111,26 @@ function apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=defa
 end


+#=
+#Doesn't help to do it this way
+"""
+    tangent_coordinates(P::DiscreteProcess, Xt::DiscreteState, X1)
+
+Computes (a weighted mixture of) Doob's h-transform(s) that would condition the current state Xt (which must be a discrete value)
+to end at X1 (which can be a distribution) under P. Maybe.
+"""
+function tangent_coordinates(P::DiscreteProcess, X1::DiscreteState, t)
+    #(for a single column) for state=i at 1-t, H_j(t)/H_i(t) is the rate scaling ratio per Doob's h-transform.
+    #If the model can learn this directly, we can gen.
+    H = backward(X1, P, 1 .- t)
+    scale = sum(H.dist, dims = 1)
+    H.dist ./= scale
+    return H
+end
+=#
+
+
 ########################################################################
 #Manifold-specific helper functions
 ########################################################################
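
The discrete floss methods above require a one-hot `X₁` (the Integer-state method exists only to raise that error). The corresponding training call, as used in examples/discrete.jl from this commit:

    loss = floss(P, model(t, Xt), onehot(X1), t)   #onehot-encode X1 before moving it to the GPU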

0 commit comments