Skip to content

Commit 22dd738

Browse files
authored
Deconv using CalcFactor dispatch (#1789)
1 parent ebe8f88 commit 22dd738

File tree

4 files changed

+82
-39
lines changed

4 files changed

+82
-39
lines changed

src/entities/CalcFactor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct CalcFactorNormSq{
5858
# which index is being solved for?
5959
solvefor::Int
6060
manifold::M
61-
measurement::MEAS
61+
measurement::MEAS #TBD make measurement only one sample per calc factor
6262
slack::S
6363
end
6464

src/services/DeconvUtils.jl

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,7 @@ function approxDeconv(
9292

9393
# find solution via SubArray view pointing to original memory location
9494
if fcttype isa AbstractManifoldMinimize
95-
sfidx = ccw.varidx[]
96-
ts = _solveLambdaNumericMeas(
97-
fcttype,
98-
hypoObj,
99-
res_,
100-
measurement[idx],
101-
getVariableType(ccw.fullvariables[sfidx]), # ccw.vartypes[sfidx](),
102-
islen1,
103-
)
104-
# copyto!(target_smpl, ts)
105-
measurement[idx] = ts
95+
error("Fix dispatch on AbstractManifoldMinimize")
10696
else
10797
ts = _solveLambdaNumeric(fcttype, hypoObj, res_, measurement[idx], islen1)
10898
measurement[idx] = ts
@@ -115,6 +105,62 @@ function approxDeconv(
115105
return measurement, fctSmpls
116106
end
117107

108+
# TBD deprecate use of xDim
109+
function approxDeconv(
110+
fcto::DFGFactor{<:CommonConvWrapper{<:AbstractManifoldMinimize}},
111+
ccw::CommonConvWrapper = _getCCW(fcto);
112+
N::Int = 100,
113+
measurement::AbstractVector = sampleFactor(ccw, N),
114+
retries=nothing,
115+
)
116+
if !isnothing(retries)
117+
Base.depwarn(
118+
"approxDeconv kwarg retries is not used",
119+
:approxDeconv,
120+
)
121+
end
122+
# but what if this is a partial factor -- is that important for general cases in deconv?
123+
_setCCWDecisionDimsConv!(ccw, 0)
124+
125+
varsyms = getVariableOrder(fcto)
126+
127+
# TODO assuming vector on only first container in measurement::Tuple # TBD How should user dispatch fancy tuple measurements on deconv.
128+
129+
# NOTE
130+
# build a lambda that incorporates the multihypo selections
131+
# deconv has to solve for the best matching for particles
132+
# FIXME This does not incorporate multihypo, Apply hyporecipe to full variable order list. But remember hyporecipe assignment must be found (NPhard)
133+
hyporecipe = _prepareHypoRecipe!(nothing, N, 0, length(varsyms))
134+
# only doing the current active hypo
135+
@assert hyporecipe.activehypo[2][1] == 1 "deconv was expecting hypothesis nr == (1, 1:d)"
136+
137+
# get measurement dimension
138+
zDim = _getZDim(fcto)
139+
islen1 = zDim == 1
140+
141+
#make a copy of the original measurement before mutating it
142+
sampled_meas = deepcopy(measurement)
143+
144+
fcttype = getFactorType(fcto)
145+
146+
for idx = 1:N
147+
148+
# TODO must first resolve hypothesis selection before unrolling them -- deferred #1096
149+
resize!(ccw.hyporecipe.activehypo, length(hyporecipe.activehypo[2][2]))
150+
ccw.hyporecipe.activehypo[:] = hyporecipe.activehypo[2][2]
151+
#TODO why is this resize in the loop?
152+
153+
# Create a CalcFactor functor of the correct hypo.
154+
_hypoCalcFactor = _buildHypoCalcFactor(ccw, idx)
155+
156+
ts = _solveLambdaNumericMeas(fcttype, _hypoCalcFactor, measurement[idx], islen1)
157+
measurement[idx] = ts
158+
159+
end
160+
161+
return measurement, sampled_meas
162+
end
163+
118164
"""
119165
$SIGNATURES
120166

src/services/NumericalCalculations.jl

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,6 @@ function (hypoCalcFactor::CalcFactorNormSq)(::Type{ManoptCalcConv}, M::AbstractM
8787
return hypoCalcFactor(CalcConv, p)
8888
end
8989

90-
#TODO untested and unused
91-
# for deconv with the measurement a tangent vector
92-
# function (hypoCalcFactor::CalcFactorNormSq)(M::AbstractManifold, Xc::AbstractVector)
93-
# # M = hypoCalcFactor.manifold # calc factor has factor manifold in not variable that is needed here
94-
# ϵ = getPointIdentity(M)
95-
# X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
96-
# return hypoCalcFactor(CalcDeconv, X)
97-
# end
98-
9990
function _solveLambdaNumeric(
10091
fcttype::Union{F, <:Mixture{N_, F, S, T}},
10192
hypoCalcFactor,
@@ -109,7 +100,6 @@ function _solveLambdaNumeric(
109100
# the variable is a manifold point, we are working on the tangent plane in optim for now.
110101
#
111102
#TODO this is not general to all manifolds, should work for lie groups.
112-
# ϵ = identity_element(M, u0)
113103
ϵ = getPointIdentity(variableType)
114104

115105
X0c = zero(MVector{getDimension(M),Float64})
@@ -142,34 +132,43 @@ function _solveLambdaNumeric(
142132
return exp(M, ϵ, hat(M, ϵ, r.minimizer))
143133
end
144134

135+
## deconvolution with calcfactor wip
136+
struct CalcDeconv end
137+
138+
function (cf::CalcFactorNormSq)(::Type{CalcDeconv}, meas)
139+
res = cf(meas, map(vvh -> _getindex_anyn(vvh, cf._sampleIdx), cf._legacyParams)...)
140+
return sum(x->x^2, res)
141+
end
142+
143+
# for deconv with the measurement a tangent vector, can dispatch for other measurement types.
144+
function (hypoCalcFactor::CalcFactorNormSq)(::Type{CalcDeconv}, M::AbstractManifold, Xc::AbstractVector)
145+
ϵ = getPointIdentity(M)
146+
X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
147+
return hypoCalcFactor(CalcDeconv, X)
148+
end
145149

146-
#TODO Consolidate with _solveLambdaNumeric, see #1374
147-
#TODO _solveLambdaNumericMeas assumes a measurement is always a tangent vector, confirm.
150+
# NOTE Optim.jl version that assumes measurement is on the tangent
151+
# TODO test / dev for n-ary factor deconv
152+
# TODO Consolidate with _solveLambdaNumeric, see #1374
148153
function _solveLambdaNumericMeas(
149154
fcttype::Union{F, <:Mixture{N_, F, S, T}},
150-
objResX::Function,
151-
residual::AbstractVector{<:Real},
155+
hypoCalcFactor,
152156
X0,#::AbstractVector{<:Real},
153-
variableType::InferenceVariable,
154157
islen1::Bool = false,
155158
) where {N_, F <: AbstractManifoldMinimize, S, T}
156159
#
157-
# Assume measurement is on the tangent
158160
M = getManifold(fcttype)
159-
# the variable is a manifold point, we are working on the tangent plane in optim for now.
160161
ϵ = getPointIdentity(M)
161162
X0c = zeros(manifold_dimension(M))
162163
X0c .= vee(M, ϵ, X0)
163164

164-
function cost(Xc)
165-
X = hat(M, ϵ, Xc)
166-
residual = objResX(X)
167-
return sum(residual .^ 2)
168-
end
169-
170165
alg = islen1 ? Optim.BFGS() : Optim.NelderMead()
171166

172-
r = Optim.optimize(cost, X0c, alg)
167+
r = Optim.optimize(
168+
x->hypoCalcFactor(CalcDeconv, M, x),
169+
X0c,
170+
alg
171+
)
173172
if !Optim.converged(r)
174173
@debug "Optim did not converge:" r
175174
end
@@ -374,7 +373,6 @@ end
374373
#
375374

376375
struct CalcConv end
377-
struct CalcDeconv end
378376

379377
_getindex_anyn(vec, n) = begin
380378
len = length(vec)
@@ -397,7 +395,7 @@ function (cf::CalcFactorNormSq)(::Type{CalcConv}, x)
397395
return sum(x->x^2, res)
398396
end
399397

400-
function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer, _slack)
398+
function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer, _slack=nothing)
401399
# build a view to the decision variable memory
402400
varValsHypo = ccwl.varValsAll[][ccwl.hyporecipe.activehypo]
403401
# create calc factor selected hypo and samples

test/testSpecialEuclidean2Mani.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,6 @@ m_θ = map(x->x.x[2][2], meas)
257257
@test isapprox(mean(p_t), mean(m_t), atol=0.3)
258258
@test isapprox(std(p_t), std(m_t), atol=0.3)
259259

260-
##
261260
end
262261

263262

0 commit comments

Comments
 (0)