Skip to content

Commit db7ff84

Browse files
authored
Merge pull request #1329 from JuliaRobotics/21Q3/fix/testpartials
wip toward ipc via perturbation on relative
2 parents ef41576 + 452dbd0 commit db7ff84

File tree

9 files changed

+232
-30
lines changed

9 files changed

+232
-30
lines changed

src/ApproxConv.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
376376
skipSolve::Bool=false,
377377
_slack=nothing ) where {T <: AbstractFactor}
378378
#
379-
# @info "EVALSPEC" string(measurement) inflateCycles
379+
380380
# Prep computation variables
381381
# NOTE #1025, should FMD be built here...
382382
sfidx, maxlen, mani = prepareCommonConvWrapper!(ccwl, Xi, solvefor, N, needFreshMeasurements=needFreshMeasurements, solveKey=solveKey)
@@ -408,7 +408,13 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
408408
_slack=_slack )
409409
#
410410
# do info per coord
411-
ipc = ones(getDimension(Xi[sfidx]))
411+
ipc = if ccwl._gradients === nothing
412+
ones(getDimension(Xi[sfidx]))
413+
else
414+
# calcPerturbationFromVariable(ccwl._gradients, 2, ipc_)
415+
ones(getDimension(Xi[sfidx]))
416+
end
417+
412418
# return the found points, and info per coord
413419
return ccwl.params[ccwl.varidx], ipc
414420
end

src/FactorGraph.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -631,17 +631,17 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
631631
pttypes = getVariableType.(Xi) .|> getPointType
632632
PointType = 0 < length(pttypes) ? pttypes[1] : Vector{Float64}
633633
# FIXME stop using Any, see #1321
634-
ARR = Vector{Vector{Any}}()
635-
maxlen, sfidx, mani = prepareparamsarray!(ARR, Xi, nothing, 0) # Nothing for init.
634+
varParams = Vector{Vector{Any}}()
635+
maxlen, sfidx, mani = prepareparamsarray!(varParams, Xi, nothing, 0) # Nothing for init.
636636

637637
# standard factor metadata
638638
sflbl = 0==length(Xi) ? :null : getLabel(Xi[end])
639-
fmd = FactorMetadata(Xi, getLabel.(Xi), ARR, sflbl, nothing)
639+
fmd = FactorMetadata(Xi, getLabel.(Xi), varParams, sflbl, nothing)
640640

641641
# create a temporary CalcFactor object for extracting the first sample
642642
# TODO, deprecate this: guess measurement points type
643643
# MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
644-
_cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, ARR) # (Vector{MeasType}(),)
644+
_cf = CalcFactor( usrfnc, fmd, 0, 1, nothing, varParams) # (Vector{MeasType}(),)
645645

646646
# get a measurement sample
647647
meas_single = sampleFactor(_cf, 1)
@@ -658,12 +658,13 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
658658
Int[]
659659
end
660660

661-
varTypes = typeof.(getVariableType.(Xi))
661+
# as per struct CommonConvWrapper
662+
varTypes::Vector{DataType} = typeof.(getVariableType.(Xi))
662663
gradients = nothing
663664
# prepare new cached gradient lambdas (attempt)
664665
# try
665666
# measurement = tuple(((x->x[1]).(meas_single))...)
666-
# pts = tuple(((x->x[1]).(ARR))...)
667+
# pts = tuple(((x->x[1]).(varParams))...)
667668
# gradients = FactorGradientsCached!(usrfnc, varTypes, measurement, pts);
668669
# catch e
669670
# @warn "Unable to create measurements and gradients for $usrfnc during prep of CCW, falling back on no-partial information assumption."
@@ -673,7 +674,7 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
673674
usrfnc,
674675
PointType[],
675676
zdim,
676-
ARR,
677+
varParams,
677678
fmd,
678679
specialzDim = hasfield(T, :zDim),
679680
partial = ispartl,

src/FactorGraphTypes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ function CommonConvWrapper( fnc::T,
300300
perturb=perturb, res=res )).(1:Threads.nthreads()),
301301
inflation,
302302
partialDims, # SVector(Int32.()...)
303-
vartypes,
303+
DataType[vartypes...],
304304
gradients)
305305
end
306306

src/entities/FactorGradients.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ mutable struct FactorGradientsCached!{F <: AbstractRelative,S,M<:Tuple,P,G,L}
7171
# nested-tuple of gradient lambda functions
7272
_λ_fncs::L
7373
_coord_sizes::Vector{Int}
74+
# gradient delta
75+
_h::Float64
7476
end
7577

7678

src/services/FactorGradients.jl

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# utilities for calculating the gradient over factors
22

33
export getCoordSizes
4+
export checkGradientsToleranceMask, calcPerturbationFromVariable
45

56

67
# T_pt_args[:] = [(T1::Type{<:InferenceVariable}, point1); ...]
@@ -96,7 +97,8 @@ function FactorGradientsCached!(fct::Union{<:AbstractRelativeMinimize, <:Abstrac
9697
pts,
9798
tfg,
9899
λ_fncs,
99-
λ_sizes )
100+
λ_sizes,
101+
h )
100102
end
101103

102104

@@ -148,4 +150,105 @@ function (fgc::FactorGradientsCached!)(meas_pts...)
148150

149151
# return newly calculated gradients
150152
return fgc.cached_gradients
151-
end
153+
end
154+
155+
"""
156+
$SIGNATURES
157+
158+
Return a mask of same size as gradients matrix `J`, indicating which elements are above the expected sensitivity threshold `tol`.
159+
160+
Notes
161+
- Threshold accuracy depends on two parts,
162+
- Numerical gradient perturbation size `fgc._h`,
163+
- Accuracy tolerance to which the factor residual is computed (not controlled here)
164+
"""
165+
function checkGradientsToleranceMask( fgc::FactorGradientsCached!,
166+
J::AbstractArray=fgc.cached_gradients;
167+
tol::Real=0.02*fgc._h )
168+
#
169+
# ignore anything 10 times smaller than numerical gradient delta used
170+
# NOTE this ignores the factor residual solve accuracy
171+
return tol*fgc._h .< abs.(J)
172+
end
173+
174+
"""
175+
$SIGNATURES
176+
177+
Return a tuple of infoPerCoord vectors that result from input `fromVar::Int`'s `infoPerCoord`.
178+
For example, a binary `LinearRelative` factor has a one to one influence from the input to the one other variable.
179+
180+
Notes
181+
182+
- Assumes the gradients in `fgc` are up to date -- if not, first run `fgc(measurement..., pts...)`.
183+
- `tol` does not recalculate the gradients to a new tolerance, instead uses the cached value in `fgc` to predict accuracy.
184+
185+
Example
186+
187+
```julia
188+
# setup
189+
fct = LinearRelative(MvNormal([10;0],[1 0; 0 1]))
190+
measurement = ([10.0;0.0],)
191+
varTypes = (ContinuousEuclid{2}, ContinuousEuclid{2})
192+
pts = ([0;0.0], [9.5;0])
193+
194+
# create the gradients functor object
195+
fgc = FactorGradientsCached!(fct, varTypes, measurement, pts);
196+
# must first update the cached gradients
197+
fgc(measurement..., pts...)
198+
199+
# check the perturbation influence through gradients on factor
200+
ret = calcPerturbationFromVariable(fgc, 1, [1;1])
201+
202+
@assert isapprox(ret[2], [1;1])
203+
```
204+
205+
DevNotes
206+
- FIXME Support n-ary source factors by extending `fromVar` to more than just one.
207+
208+
Related
209+
210+
[`FactorGradientsCached!`](@ref), [`checkGradientsToleranceMask`](@ref)
211+
"""
212+
function calcPerturbationFromVariable(fgc::FactorGradientsCached!,
213+
fromVar::Int,
214+
infoPerCoord::AbstractVector;
215+
tol::Real=0.02*fgc._h )
216+
#
217+
blkszs = getCoordSizes(fgc)
218+
@assert blkszs[fromVar] == length(infoPerCoord) "Expecting incoming length(infoPerCoord) to equal the block size for variable $fromVar, as per factor used to construct the FactorGradientsCached!: $(getFactorType(fgc.dfgfct))"
219+
# assume projection through pp-factor from first to second variable
220+
# ipc values from first variable belief, and zero for second
221+
ipcAll = zeros(sum(blkszs))
222+
223+
# nextVar = minimum([fromVar+1; length(blkszs)])
224+
curr_b = sum(blkszs[1:(fromVar-1)]) + 1
225+
curr_e = sum(blkszs[1:fromVar])
226+
ipcAll[curr_b:curr_e] .= infoPerCoord
227+
228+
# clamp gradients below numerical solver resolution
229+
mask = checkGradientsToleranceMask(fgc, tol=tol)
230+
J = fgc.cached_gradients
231+
_J = zeros(size(J)...)
232+
_J[mask] .= J[mask]
233+
234+
# calculate the gradient influence on other variables
235+
ipc_pert = _J * ipcAll
236+
237+
# round up over numerical solution tolerance
238+
dig = floor(Int, log10(1/tol))
239+
ipc_pert .= round.(ipc_pert, digits=dig)
240+
241+
# slice the result
242+
ipcBlk = []
243+
for (i, sz) in enumerate(blkszs)
244+
curr_b = sum(blkszs[1:(i-1)]) + 1
245+
curr_e = sum(blkszs[1:i])
246+
blk_ = view(ipc_pert, curr_b:curr_e)
247+
push!(ipcBlk, blk_)
248+
end
249+
250+
return tuple(ipcBlk...)
251+
end
252+
253+
254+
#

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ using Test
66
# temporarily moved to start (for debugging)
77

88
# @error "MUST RESTORE testpartialconstraint.jl"
9+
include("testFactorGradients.jl")
910
include("testpartialconstraint.jl")
1011
include("testPartialNH.jl")
1112

1213
include("testSphereMani.jl")
1314
include("testSpecialOrthogonalMani.jl")
1415
include("testSpecialEuclidean2Mani.jl")
1516

16-
include("testFactorGradients.jl")
1717

1818
include("testCliqSolveDbgUtils.jl")
1919
include("TestModuleFunctions.jl")

test/testFactorGradients.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using TensorCast
66
using Manifolds
77
using Test
88

9+
# overloading with new dispatch
10+
import IncrementalInference: getSample, getManifold
11+
912
##
1013

1114
@testset "test manual call to gradient lambda utilities" begin
@@ -48,8 +51,82 @@ J = gradFct(measurement..., pts...)
4851

4952
@test norm( J - [0 0 1 0; 0 0 0 1; 1 0 0 0; 0 1 0 0] ) < 1e-4
5053

54+
## check on transmitted info per coords
55+
56+
ret = calcPerturbationFromVariable(gradFct, 1, [1;1])
57+
58+
# the fromVar itself should be zero
59+
@test length(ret[1]) == 2
60+
@test isapprox( ret[1], [0;0], atol=1e-6 )
61+
62+
# the other variable
63+
@test length(ret[2]) == 2
64+
@test isapprox( ret[2], [1;1], atol=1e-6 )
5165

5266
##
5367
end
5468

69+
##
70+
71+
struct TestPartialRelative2D{B <: SamplableBelief} <: IIF.AbstractRelativeMinimize
72+
Z::B
73+
partial::Tuple{Int}
74+
end
75+
# standard helper with partial set
76+
TestPartialRelative2D(z::SamplableBelief) = TestPartialRelative2D(z, (2,))
77+
78+
# imported earlier for overload
79+
getManifold(fnc::TestPartialRelative2D) = TranslationGroup(2)
80+
getSample(cf::CalcFactor{<:TestPartialRelative2D},N=1) = ([rand(cf.factor.Z) for _ in 1:N], )
81+
82+
# currently requires residual to be returned as a tangent vector element
83+
(cf::CalcFactor{<:TestPartialRelative2D})(z, x1, x2) = x2[2:2] - (x1[2:2] + z[1:1])
84+
85+
##
86+
87+
@testset "test a partial, binary relative factor perturbation (a new user factor)" begin
88+
##
89+
90+
91+
tpr = TestPartialRelative2D(Normal(10,1))
92+
93+
measurement = ([10.0;], )
94+
pts = ([0;0.0], [0;10.0])
95+
varTypes = (ContinuousEuclid{2}, ContinuousEuclid{2})
96+
97+
## construct the lambdas
98+
gradients = FactorGradientsCached!(tpr, varTypes, measurement, pts);
99+
100+
## calculate new gradients
101+
J = gradients(measurement..., pts...)
102+
103+
104+
## check on transmitted info per coords
105+
106+
ret = calcPerturbationFromVariable(gradients, 1, [1;1])
107+
108+
# the fromVar itself should be zero
109+
@test length(ret[1]) == 2
110+
@test isapprox( ret[1], [0;0], atol=1e-6 )
111+
112+
# the other variable only affects the second coordinate dimension
113+
@test length(ret[2]) == 2
114+
@test isapprox( ret[2], [0;1], atol=1e-6 )
115+
116+
## check the reverse perturbation
117+
118+
ret = calcPerturbationFromVariable(gradients, 2, [1;1])
119+
120+
# only the first coordinate dimension is affected
121+
@test length(ret[1]) == 2
122+
@test isapprox( ret[1], [0;1], atol=1e-6 )
123+
124+
## the fromVar itself should be zero
125+
@test length(ret[2]) == 2
126+
@test isapprox( ret[2], [0;0], atol=1e-6 )
127+
128+
##
129+
end
130+
131+
55132
#

test/testGradientUtils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@ using Test
88
## test utility to build a temporary graph
99

1010
fct = EuclidDistance(Normal(10,1))
11-
T_pt_vec = [(ContinuousScalar,ContinuousScalar); ([0;],[9.5;])]
11+
varTypes = (ContinuousScalar,ContinuousScalar);
12+
varPts = ([0;],[9.5;])
1213

1314
##
1415

15-
dfg, _dfgfct = IIF._buildGraphByFactorAndTypes!(fct, T_pt_vec...)
16+
dfg, _dfgfct = IIF._buildGraphByFactorAndTypes!(fct, varTypes, varPts)
1617

1718
@test length(intersect(ls(dfg), [:x1; :x2])) == 2
1819
@test lsf(dfg) == [:x1x2f1;]
1920

2021
## test the evaluation of factor without
2122

22-
B = IIF._evalFactorTemporary!(EuclidDistance(Normal(10,1)), 2, ([10;],), T_pt_vec... );
23+
B = IIF._evalFactorTemporary!(EuclidDistance(Normal(10,1)), 2, ([10;],), varTypes, varPts );
2324

2425
@test_broken B isa Vector{Vector{Float64}}
2526
@test isapprox( B[1], [10.0;], atol=1e-6)

0 commit comments

Comments
 (0)