Skip to content

Commit 525602e

Browse files
authored
Merge pull request #1776 from JuliaRobotics/23Q3/fix/pospointtype
Fix getPointType of Position{N}
2 parents a33dd76 + 21cddbe commit 525602e

13 files changed

+99
-62
lines changed

src/Variables/DefaultVariables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ DFG.getManifold(::InstanceType{Position{N}}) where {N} = TranslationGroup(N)
1515
function DFG.getDimension(val::InstanceType{Position{N}}) where {N}
1616
return manifold_dimension(getManifold(val))
1717
end
18-
DFG.getPointType(::Type{Position{N}}) where {N} = Vector{Float64}
18+
DFG.getPointType(::Type{Position{N}}) where {N} = SVector{N, Float64}
1919
DFG.getPointIdentity(M_::Type{Position{N}}) where {N} = @SVector(zeros(N)) # identity_element(getManifold(M_), zeros(N))
2020

2121
function Base.convert(

src/entities/ExtFactors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ DevNotes
1919
- FIXME Lots of consolidation and standardization to do, see RoME.jl #244 regarding Manifolds.jl.
2020
- TODO does not yet handle case where a factor spans across two timezones.
2121
"""
22-
struct DERelative{T <: InferenceVariable, P, D} <: AbstractRelativeMinimize
22+
struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # AbstractRelativeMinimize
2323
domain::Type{T}
2424
forwardProblem::P
2525
backwardProblem::P

src/manifolds/services/ManifoldSampling.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ end
2424

2525
# Sampling Distributions
2626
# assumes M is a group and will break for Riemannian, but leaving that enhancement as TODO
27-
function sampleTangent(M::AbstractManifold, z::Distribution, p = getPointIdentity(M), basis::AbstractBasis = DefaultOrthogonalBasis())
27+
function sampleTangent(
28+
M::AbstractManifold,
29+
z::Distribution,
30+
p = getPointIdentity(M),
31+
basis::AbstractBasis = DefaultOrthogonalBasis()
32+
)
2833
return get_vector(M, p, rand(z), basis)
2934
end
3035

src/manifolds/services/ManifoldsExtentions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ end
105105

106106
# fallback
107107
function getPointIdentity(G::GroupManifold, ::Type{T} = Float64) where {T <: Real}
108-
return error("getPointIdentity not implemented on G")
108+
return error("getPointIdentity not implemented on $G")
109109
end
110110

111111
function getPointIdentity(

src/parametric/services/ParametricCSMFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
3838
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
3939
# fill in the variable node data value
4040
logCSM(csmc, "$(csmc.cliq.id) up: updating $v : $val")
41-
vnd.val[1] .= val.val
41+
vnd.val[1] = val.val
4242
#calculate and fill in covariance
4343
#TODO rather broadcast than make new memory
4444
vnd.bw = val.cov
@@ -146,7 +146,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
146146
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
147147
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
148148
#Update subfg variables
149-
vnd.val[1] .= val.val
149+
vnd.val[1] = val.val
150150
vnd.bw .= val.cov
151151
end
152152
else

src/services/GraphInit.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,11 @@ Notes:
123123
- Special carve out for multihypo cases, see issue 427.
124124
125125
Development Notes:
126-
> Target factor is first (singletons) or second (dim 2 pairwise) variable vertex in `xi`.
127-
* TODO use DFG properly with local operations and DB update at end.
128-
* TODO get faster version of `isInitialized` for database version.
129-
* TODO: Persist this back if we want to here.
126+
- Target factor is first (singletons) or second (dim 2 pairwise) variable vertex in `xi`.
127+
- TODO use DFG properly with local operations and DB update at end.
128+
- TODO get faster version of `isInitialized` for database version.
129+
- TODO: Persist this back if we want to here.
130+
- TODO: init from just partials
130131
"""
131132
function doautoinit!(
132133
dfg::AbstractDFG,

src/services/GraphProductOperations.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ function propagateBelief(
3636
# make sure oldPoints vector has right length
3737
oldBel = getBelief(dfg, destlbl, solveKey)
3838
_pts = getPoints(oldBel, false)
39-
oldPoints = if Npts(oldBel) <= N
40-
_pts[1:N]
41-
else
39+
oldPoints = if Npts(oldBel) < N
4240
nn = N - length(_pts) # should be larger than 0
43-
vcat(_pts, sample(oldBel, nn))
41+
_pts_, = sample(oldBel, nn)
42+
vcat(_pts, _pts_)
43+
else
44+
_pts[1:N]
4445
end
4546

4647
# few more data requirements

src/services/NumericalCalculations.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,9 @@ function _solveCCWNumeric!(
362362
perturb::Real = 1e-10,
363363
) where {N_, F <: AbstractRelative, S, T}
364364
#
365-
365+
366366
#
367367
# thrid = Threads.threadid()
368-
369368
smpid = ccwl.particleidx[]
370369
# cannot Nelder-Mead on 1dim, partial can be 1dim or more but being conservative.
371370
islen1 = length(ccwl.partialDims) == 1 || ccwl.partial
@@ -377,10 +376,11 @@ function _solveCCWNumeric!(
377376
# a separate deepcopy of the destination (aka target) memory is necessary.
378377
# Choosen solution is to splice together ccwl.varValsAll each time, with destination as
379378
# deepcopy but other input variables are just point to the source variable values directly.
380-
if ccwl.partial
381-
target = view(ccwl.varValsAll[][ccwl.varidx[]][smpid], ccwl.partialDims)
379+
target = if ccwl.partial # FIXME likely type-instability on `typeof(target)`
380+
# view(ccwl.varValsAll[][ccwl.varidx[]][smpid], ccwl.partialDims)
381+
ccwl.varValsAll[][ccwl.varidx[]][smpid][ccwl.partialDims]
382382
else
383-
target = ccwl.varValsAll[][ccwl.varidx[]][smpid];
383+
ccwl.varValsAll[][ccwl.varidx[]][smpid]
384384
end
385385
# build the pre-objective function for this sample's hypothesis selection
386386
unrollHypo!, _ = _buildCalcFactorLambdaSample(
@@ -407,12 +407,13 @@ function _solveCCWNumeric!(
407407
# target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
408408
sfidx = ccwl.varidx[]
409409
# do the parameter search over defined decision variables using Minimization
410-
if ccwl.partial
411-
X = collect(view(ccwl.varValsAll[][sfidx][smpid], ccwl.partialDims))
412-
else
413-
X = ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims]
414-
end
415-
# X = destVarVals[smpid]#[ccwl.partialDims]
410+
X = ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims]
411+
# X = if ccwl.partial # TODO check for type-instability on `X`
412+
# collect(view(ccwl.varValsAll[][sfidx][smpid], ccwl.partialDims))
413+
# else
414+
# ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims]
415+
# end
416+
# # X = destVarVals[smpid]#[ccwl.partialDims]
416417

417418
retval = _solveLambdaNumeric(
418419
getFactorType(ccwl),
@@ -430,7 +431,13 @@ function _solveCCWNumeric!(
430431

431432
# insert result back at the correct variable element location
432433
if ccwl.partial
433-
ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims] .= retval
434+
# NOTE use workaround of TranslationGroup for coordinates on partial assignment
435+
# FIXME consolidate to Manopt and upgrade to Riemannian (i.e. incl non-groups)
436+
M = getManifold(ccwl) # TranslationGroup(length(ccwl.varValsAll[][sfidx][smpid]))
437+
src = Vector{typeof(retval)}()
438+
push!(src, retval)
439+
setPointPartial!(M, ccwl.varValsAll[][sfidx], M, src, ccwl.partialDims, smpid, 1, true )
440+
# ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims] .= retval
434441
else
435442
# copyto!(ccwl.varValsAll[sfidx][smpid], retval)
436443
copyto!(ccwl.varValsAll[][sfidx][smpid][ccwl.partialDims], retval)

test/runtests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,30 @@ TEST_GROUP = get(ENV, "IIF_TEST_GROUP", "all")
55
# temporarily moved to start (for debugging)
66
#...
77
if TEST_GROUP in ["all", "tmp_debug_group"]
8-
include("testDERelative.jl")
98
include("testSpecialOrthogonalMani.jl")
109
include("testMultiHypo3Door.jl")
1110
include("priorusetest.jl")
1211
end
1312

1413
if TEST_GROUP in ["all", "basic_functional_group"]
15-
# more frequent stochasic failures from numerics
14+
# more frequent stochasic failures from numerics
1615
include("manifolds/manifolddiff.jl")
1716
include("manifolds/factordiff.jl")
1817
include("testSpecialEuclidean2Mani.jl")
1918
include("testEuclidDistance.jl")
2019

21-
# regular testing
22-
include("testSphereMani.jl")
23-
include("testBasicManifolds.jl")
24-
2520
# start as basic as possible and build from there
2621
include("typeReturnMemRef.jl")
2722
include("testDistributionsGeneric.jl")
28-
include("testHeatmapGridDensity.jl")
2923
include("testCliqSolveDbgUtils.jl")
3024
include("basicGraphsOperations.jl")
3125

26+
# regular testing
27+
include("testSphereMani.jl")
28+
include("testBasicManifolds.jl")
29+
include("testDERelative.jl")
30+
include("testHeatmapGridDensity.jl")
31+
3232
# include("TestModuleFunctions.jl")
3333
include("testCompareVariablesFactors.jl")
3434
include("saveconvertertypes.jl")

test/testBasicParametric.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ end
124124

125125
##
126126

127-
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] .= x.second.val, pairs(d))
127+
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] = x.second.val, pairs(d))
128128

129129

130130
# getSolverParams(fg).dbg=true
@@ -186,7 +186,7 @@ foreach(println, d)
186186

187187
##
188188

189-
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] .= x.second.val, pairs(d))
189+
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] = x.second.val, pairs(d))
190190

191191
# fg.solverParams.showtree = true
192192
# fg.solverParams.drawtree = true
@@ -228,7 +228,7 @@ for i in 0:10
228228
@test isapprox(d[sym].val[1], i, atol=1e-6)
229229
end
230230

231-
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] .= x.second.val, pairs(d))
231+
foreach(x->getSolverData(getVariable(fg,x.first),:parametric).val[1] = x.second.val, pairs(d))
232232

233233
# fg.solverParams.showtree = true
234234
# fg.solverParams.drawtree = true

0 commit comments

Comments
 (0)