|
1 | | -#=###### |
2 | | -NOTES on what works: |
3 | | -- Euclidean state: |
4 | | -- - any compatible process, using floss |
5 | | -- Manifold state: |
6 | | -- - any compatible process, using tcloss |
7 | | -- Discrete state: |
8 | | -- - for a DiscreteProcess, only UniformUnmasking works properly. The rest have issues. |
9 | | -- - works when using the ProbabilitySimplex in a ManifoldProcess.
10 | | -- - Either: |
11 | | -- - - The process must have non-zero variance |
12 | | -- - - or X0 must be a continuous distribution (i.e. not discrete "corners") on the ProbabilitySimplex (in which case a deterministic process also works)
13 | | -=####### |
14 | | - |
15 | | -#This is badness that doesn't work: |
16 | | -#rotangle(rots::AbstractArray{T,3}) where T = acos.(clamp.((rots[1,1,:] .+ rots[2,2,:] .+ rots[3,3,:] .- 1) ./ 2, T(-0.99), T(0.99))) |
17 | | -#rotangle(rots::AbstractArray) = reshape(rotangle(reshape(rots, 3, 3, :)), 1, size(rots)[3:end]...) |
18 | | -#torangle(x, y) = mod.(y .- x .+ π, 2π) .- π |
19 | | -#msra(X̂₁, X₁) = rotangle(batched_mul(batched_transpose(tensor(X̂₁)), tensor(X₁))).^2 #Mean Squared Angle |
20 | | -#msta(X̂₁, X₁) = sum(torangle(tensor(X̂₁), tensor(X₁)), dims=1).^2 #Mean Squared Toroidal Angle |
21 | | - |
"""
    mse(X̂₁, X₁)

Elementwise squared error between the predicted and target tensors
(Mean Squared Error before masking/reduction).
"""
function mse(X̂₁, X₁)
    Δ = tensor(X̂₁) .- tensor(X₁)
    return abs2.(Δ)
end

"""
    lce(X̂₁, X₁)

Logit cross entropy: `X̂₁` holds unnormalized logits and `X₁` a (one-hot)
target distribution. Reduces over the first (category) dimension.
"""
function lce(X̂₁, X₁)
    logp = logsoftmax(tensor(X̂₁))
    return -sum(tensor(X₁) .* logp, dims = 1)
end

"""
    kl(P, Q)

Kullback-Leibler divergence along the first dimension. `P` is given as
logits (`softmax`/`logsoftmax` are applied); `Q` is only passed through
`log`. NOTE(review): the asymmetric treatment assumes `Q` is already a
normalized probability distribution — confirm against callers.
"""
function kl(P, Q)
    p = softmax(tensor(P))
    return sum(p .* (logsoftmax(tensor(P)) .- log.(tensor(Q))), dims = 1)
end
@@ -55,45 +34,11 @@ floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), |
#For a discrete process, X̂₁ is a predicted distribution and X₁ must be one-hot
#encoded (before any GPU transfer) for the cross-entropy path to apply.
function floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:AbstractArray{<:Integer}}), c)
    #Integer-encoded targets cannot be used directly; fail loudly with instructions.
    return error("X₁ needs to be onehot encoded with `onehot(X₁)`. You might need to do this before moving it to the GPU.")
end

function floss(P::fbu(DiscreteProcess), X̂₁, X₁::msu(DiscreteState{<:OneHotArray}), c)
    #Masked, scaled mean of the logit cross entropy against the one-hot target.
    return scaledmaskedmean(lce(X̂₁, X₁), c, getlmask(X₁))
end
58 | | -#floss(P::fbu(ManifoldProcess{Rotations(3)}), X̂₁, X₁::msu(ManifoldState{Rotations(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁)) |
59 | | -#floss(P::fbu(ManifoldProcess{SpecialOrthogonal(3)}), X̂₁, X₁::msu(ManifoldState{SpecialOrthogonal(3)}), c) = scaledmaskedmean(msra(X̂₁, X₁), c, getlmask(X₁)) |
60 | | -#floss(P::fbu(ManifoldProcess), X̂₁, X₁::msu(ManifoldState{<:Torus}), c) = scaledmaskedmean(msta(X̂₁, X₁), c, getlmask(X₁)) |
61 | | - |
#Tuple fallbacks: apply floss componentwise across tupled processes/states
#and sum the per-component losses into a single scalar objective.
function floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Union{AbstractArray, Real})
    #A shared scalar/array weight is shielded from broadcasting via a 1-tuple.
    return sum(floss.(P, X̂₁, X₁, (c,)))
end

function floss(P::Tuple, X̂₁::Tuple, X₁::Tuple, c::Tuple)
    #Per-component weights: broadcast in lockstep with the other tuples.
    return sum(floss.(P, X̂₁, X₁, c))
end
64 | | - |
65 | | -#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights. |
66 | | - |
67 | | -""" |
68 | | - tcloss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ, c, mask = nothing) |
69 | | -
|
70 | | -Where `ξhat` is the predicted tangent coordinates, and `ξ` is the true tangent coordinates. |
71 | | -""" |
72 | 39 | floss(P::Union{fbu(ManifoldProcess), fbu(Deterministic)}, ξhat, ξ::Guide, c) = scaledmaskedmean(mse(ξhat, ξ.H), c, getlmask(ξ)) |
73 | | -#tcloss(P::fbu(DiscreteProcess), ξhat, ξ, c, mask = nothing) = scaledmaskedmean(rkl(ξhat, ξ), c, mask) |
74 | | - |
75 | | - |
76 | | - |
77 | | - |
78 | | -#= |
79 | | -#Doesn't help to do it this way |
80 | | -""" |
81 | | - tangent_coordinates(P::DiscreteProcess, Xt::DiscreteState, X1) |
82 | | -
|
83 | | -Computes (a weighted mixture of) Doob's h-transform(s) that would condition the current state Xt (which must be a discrete value) |
84 | | -to end at X1 (which can be a distribution) under P. Maybe. |
85 | | -""" |
86 | | -function tangent_coordinates(P::DiscreteProcess, X1::DiscreteState, t) |
87 | | - #(for a single column) for state=i at 1-t, H_j(t)/H_i(t) is the rate scaling ratio per Doob's h-transform. |
88 | | - #If the model can learn this directly, we can gen. |
89 | | - H = backward(X1, P, 1 .- t) |
90 | | - scale = sum(H.dist, dims = 1) |
91 | | - H.dist ./= scale |
92 | | - return H |
93 | | -end |
94 | | -=# |
95 | | - |
96 | 40 |
|
| 41 | +#I should make a self-balancing loss that tracks the running mean/std and adaptively scales to balance against target weights. |
97 | 42 |
|
98 | 43 | ######################################################################## |
99 | 44 | #Manifold-specific helper functions |
|
0 commit comments