-#=######
-NOTES on what works:
-- Euclidean state:
-- - any compatible process, using floss
-- Manifold state:
-- - any compatible process, using tcloss
-- Discrete state:
-- - for a DiscreteProcess, only UniformUnmasking works properly. The rest have issues.
-- - works when using the ProbabilitySimplex in a ManifoldProcess.
-- - Either:
-- - - the process must have non-zero variance,
-- - - or X0 must be a continuous distribution (i.e. not discrete "corners") on the ProbabilitySimplex (in which case a deterministic process also works).
-=#######
-
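To make the last distinction in those notes concrete: a one-hot "corner" sits on the boundary of the probability simplex, while a smoothed X0 has strictly positive mass everywhere. A small sketch using Manifolds.jl's `ProbabilitySimplex` (the constructor and `is_point` are standard Manifolds.jl; their use here is purely illustrative, not this package's API):

```julia
using Manifolds

M = ProbabilitySimplex(2)              # distributions over 3 categories
corner   = [1.0, 0.0, 0.0]             # discrete "corner": zero entries, sits on the boundary
interior = [0.90, 0.05, 0.05]          # continuous X0: strictly positive, a valid interior point
@assert is_point(M, interior)          # the smoothed point lies on the manifold proper
```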
-#This is badness that doesn't work:
-#rotangle(rots::AbstractArray{T,3}) where T = acos.(clamp.((rots[1,1,:] .+ rots[2,2,:] .+ rots[3,3,:] .- 1) ./ 2, T(-0.99), T(0.99)))
-#rotangle(rots::AbstractArray) = reshape(rotangle(reshape(rots, 3, 3, :)), 1, size(rots)[3:end]...)
-#torangle(x, y) = mod.(y .- x .+ π, 2π) .- π
-#msra(X̂₁, X₁) = rotangle(batched_mul(batched_transpose(tensor(X̂₁)), tensor(X₁))).^2 #Mean Squared Angle
-#msta(X̂₁, X₁) = sum(torangle(tensor(X̂₁), tensor(X₁)), dims=1).^2 #Mean Squared Toroidal Angle
-
mse(X̂₁, X₁) = abs2.(tensor(X̂₁) .- tensor(X₁)) #Mean Squared Error
lce(X̂₁, X₁) = -sum(tensor(X₁) .* logsoftmax(tensor(X̂₁)), dims=1) #Logit Cross Entropy
kl(P,Q) = sum(softmax(tensor(P)) .* (logsoftmax(tensor(P)) .- log.(tensor(Q))), dims=1) #Kullback-Leibler Divergence
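These three kernels operate elementwise and are only reduced later by `scaledmaskedmean`. As a sanity check of what `lce` computes, here is a minimal standalone sketch; it assumes `tensor` acts as the identity on plain arrays, which is an assumption made for illustration rather than a documented property of the package.

```julia
# Standalone sketch of the lce kernel (assumes tensor() is the identity on raw arrays).
using NNlib: logsoftmax

tensor(x) = x                                     # stand-in for the package's tensor()
lce(X̂₁, X₁) = -sum(tensor(X₁) .* logsoftmax(tensor(X̂₁)), dims=1)

logits  = randn(Float32, 4, 3)                    # 4 classes, 3 positions
targets = Float32[1 0 0; 0 1 0; 0 0 1; 0 0 0]     # one-hot columns
# For one-hot targets, lce reduces to -log p(correct class) per column:
manual = -[logsoftmax(logits)[findfirst(==(1f0), targets[:, j]), j] for j in 1:3]
@assert vec(lce(logits, targets)) ≈ manual
```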
@@ -55,45 +34,11 @@ floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState),
#For a discrete process, X̂₁ will be a distribution, and X₁ will have to be one-hot encoded before going onto the GPU.
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c) = error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c) = scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
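A minimal sketch of the encode-then-move pattern that the error message above is asking for. `onehotbatch` from OneHotArrays.jl stands in for this package's `onehot(X₁)`, and `gpu` is Flux's device mover (a no-op when no GPU is available); both substitutions are assumptions for illustration only.

```julia
using Flux: gpu
using OneHotArrays: onehotbatch

labels = [2, 1, 3, 3]                 # integer-coded discrete sequence, 3 categories
X₁hot  = onehotbatch(labels, 1:3)     # 3×4 OneHotArray, built while still on the CPU
X₁dev  = gpu(X₁hot)                   # safe to move now; the OneHotArray method of floss applies
# floss(P, X̂₁, X₁dev, c)              # hypothetical call, given a DiscreteProcess P, logits X̂₁, and scale c
```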
-#floss(P::fbu(ManifoldProcess{Rotations(3)}), X̂₁, X₁::msu(ManifoldState{Rotations(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
-#floss(P::fbu(ManifoldProcess{SpecialOrthogonal(3)}), X̂₁, X₁::msu(ManifoldState{SpecialOrthogonal(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁))
-#floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = scaledmaskedmean(msta(X̂₁, X₁), c, getlmask(X₁))
-
floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Union{AbstractArray, Real}) = sum(floss.(P, X̂₁, X₁, (c,)))
floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Tuple) = sum(floss.(P, X̂₁, X₁, c))
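The two `Tuple` methods differ only in how the scale `c` is paired with the components: wrapping `c` as a 1-tuple `(c,)` protects it from being zipped across components, so every component's loss receives the whole `c` (which matters when `c` is an array of per-position weights), whereas a `Tuple` of scales is paired one-per-component. A toy illustration of that broadcasting behaviour, using a toy stand-in rather than the real `floss`:

```julia
# Toy stand-in, purely to show the (c,) vs c::Tuple broadcasting behaviour.
toy_floss(P, x̂, x, c) = c * abs2(x̂ - x)

P  = (:proc1, :proc2)
X̂₁ = (1.0, 2.0)
X₁ = (1.5, 1.0)

shared   = sum(toy_floss.(P, X̂₁, X₁, (0.5,)))      # the 1-tuple broadcasts: both components see 0.5
per_comp = sum(toy_floss.(P, X̂₁, X₁, (0.5, 2.0)))  # a full tuple is zipped: one scale per component

@assert shared   ≈ 0.5 * 0.25 + 0.5 * 1.0
@assert per_comp ≈ 0.5 * 0.25 + 2.0 * 1.0
```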
-
-#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights.
-
-"""
-    tcloss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ, c, mask = nothing)
-
-Where `ξhat` is the predicted tangent coordinates, and `ξ` is the true tangent coordinates.
-"""
floss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ::Guide, c) = scaledmaskedmean(mse(ξhat, ξ.H), c, getlmask(ξ))
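For manifold (and deterministic) processes, the regression target is a set of tangent coordinates rather than a point, so the loss is a plain squared error in those coordinates. A sketch of what that means, using Manifolds.jl on the 2-sphere; the basis choice and the reading of `ξ.H` follow the method above, but the exact conventions are assumptions for illustration.

```julia
using Manifolds

M  = Sphere(2)
Xt = [1.0, 0.0, 0.0]                                       # current point on the manifold
X₁ = [0.0, 1.0, 0.0]                                       # target point
v  = log(M, Xt, X₁)                                        # tangent vector at Xt pointing toward X₁
ξ  = get_coordinates(M, Xt, v, DefaultOrthonormalBasis())  # coordinates, length = manifold dimension

# A model predicting ξhat in the same basis is trained with squared error against ξ,
# which is what mse(ξhat, ξ.H) does per position in the method above.
ξhat = ξ .+ 0.1
loss = sum(abs2, ξhat .- ξ)
```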
-#tcloss(P::fbu(DiscreteProcess), ξhat, ξ, c, mask = nothing) = scaledmaskedmean(rkl(ξhat, ξ), c, mask)
-
-
-
-
-#=
-#Doesn't help to do it this way
-"""
-    tangent_coordinates(P::DiscreteProcess, Xt::DiscreteState, X1)
-
-Computes (a weighted mixture of) Doob's h-transform(s) that would condition the current state Xt (which must be a discrete value)
-to end at X1 (which can be a distribution) under P. Maybe.
-"""
-function tangent_coordinates(P::DiscreteProcess, X1::DiscreteState, t)
-    #(for a single column) for state=i at 1-t, H_j(t)/H_i(t) is the rate scaling ratio per Doob's h-transform.
-    #If the model can learn this directly, we can gen.
-    H = backward(X1, P, 1 .- t)
-    scale = sum(H.dist, dims = 1)
-    H.dist ./= scale
-    return H
-end
-=#
-
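A small self-contained check of the normalisation step in the commented-out `tangent_coordinates` above (it does not call the package's `backward`): dividing a column of h-values by its sum leaves every ratio H_j/H_i, and hence every Doob rate scaling, unchanged, so the rescaling only controls the magnitude the model has to predict.

```julia
H  = [0.1, 0.6, 0.3]                  # toy h-values for one column of a 3-state process
Hn = H ./ sum(H)                      # the normalisation done in the commented-out code
@assert Hn[2] / Hn[1] ≈ H[2] / H[1]   # rate-scaling ratios are preserved
@assert Hn[3] / Hn[2] ≈ H[3] / H[2]
```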

+#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights.

########################################################################
#Manifold-specific helper functions