# copytensor! and predictresolve are used to handle the state translation that happens in gen(...).
# We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc.).
# This means we need to automagically create a State (typical for the continuous case) or a Likelihood (typical for the discrete case) from the tensor.
# But the user may return a State in the discrete case (for massive state spaces with sub-linear sampling), or a Likelihood in the continuous case (for variance-matching models).
# This also needs to handle MaskedStates (needs testing).
# We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
# Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass it through.
# When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
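# A minimal sketch of the pass-through / conversion rules described above (an illustration only,
# not the actual predictresolve implementation; the method name and constructors used here,
# e.g. resolve_sketch, CategoricalLikelihood, ContinuousState, are assumptions):
resolve_sketch(P, x̂::Union{State, Likelihood}) = x̂                                # already wrapped: pass through
resolve_sketch(P::DiscreteProcess, x̂::AbstractArray) = CategoricalLikelihood(x̂)   # tensor → per-site distribution
resolve_sketch(P, x̂::AbstractArray) = ContinuousState(x̂)                          # tensor → continuous State by default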
Constructs a sequence of (stochastic) bridges between `X0` and the predicted `X̂₁` under the process `P`.
`P` and `X0` can also be tuples, where the Nth element of `P` will be used for the Nth elements of `X0` and `model`.
`model` is a function that takes `t` (scalar) and `Xₜ` (optionally a tuple) and returns `X̂₁` (a `UState`, a flat tensor with the right shape, or a tuple of either if you're combining processes).
If `X0` is a `MaskedState`, then anything in `X̂₁` will be conditioned on `X0` wherever the conditioning mask `X0.cmask` is 1.
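For example, a minimal usage sketch (the process constructor, the step grid, and the exact `gen(...)` call are assumptions; only the `model` contract comes from the text above):

P  = BrownianMotion()                      # a simple continuous process (constructor assumed)
x0 = randn(Float32, 2, 16)
X0 = ContinuousState(x0)
model(t, Xₜ) = zero(x0)                    # stand-in for a trained network: returns a plain X̂₁ tensor
samp = gen(P, X0, model, 0f0:0.05f0:1f0)   # bridges from X0 toward the predicted X̂₁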
floss(P::fbu(OrnsteinUhlenbeck), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
# For a discrete process, X̂₁ will be a distribution, and X₁ will have to be onehot encoded before going onto the GPU.
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(ManifoldProcess{Rotations(3)}), X̂₁, X₁::msu(ManifoldState{Rotations(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(ManifoldProcess{SpecialOrthogonal(3)}), X̂₁, X₁::msu(ManifoldState{SpecialOrthogonal(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = scaledmaskedmean(msta(X̂₁, X₁), c, getlmask(X₁))
Where `ξhat` are the predicted tangent coordinates and `ξ` are the true tangent coordinates.
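A hedged sketch of the onehot requirement for the discrete loss (the constructors and the `c` scaling argument are assumptions; the "onehot before GPU" ordering comes from the error message above):

P   = UniformDiscrete()                  # some discrete process (constructor assumed)
X₁  = DiscreteState(20, rand(1:20, 100)) # integer-encoded target state (constructor assumed)
X₁h = onehot(X₁)                         # onehot encode on the CPU, then move to the GPU
# loss = floss(P, X̂₁, X₁h, c)            # X̂₁ is the model's predicted distribution, c the loss scaling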