11# utilities for calculating the gradient over factors
22
33export getCoordSizes
4+ export checkGradientsToleranceMask, calcPerturbationFromVariable
45
56
67# T_pt_args[:] = [(T1::Type{<:InferenceVariable}, point1); ...]
@@ -96,7 +97,8 @@ function FactorGradientsCached!(fct::Union{<:AbstractRelativeMinimize, <:Abstrac
9697 pts,
9798 tfg,
9899 λ_fncs,
99- λ_sizes )
100+ λ_sizes,
101+ h )
100102end
101103
102104
@@ -148,4 +150,105 @@ function (fgc::FactorGradientsCached!)(meas_pts...)
148150
149151 # return newly calculated gradients
150152 return fgc. cached_gradients
151- end
153+ end
154+
155+ """
156+ $SIGNATURES
157+
158+ Return a mask of same size as gradients matrix `J`, indicating which elements are above the expected sensitivity threshold `tol`.
159+
160+ Notes
161+ - Threshold accuracy depends on two parts,
162+ - Numerical gradient perturbation size `fgc._h`,
163+ - Accuracy tolerance to which the factor residual is computed (not controlled here)
164+ """
165+ function checkGradientsToleranceMask ( fgc:: FactorGradientsCached! ,
166+ J:: AbstractArray = fgc. cached_gradients;
167+ tol:: Real = 0.02 * fgc. _h )
168+ #
169+ # ignore anything 10 times smaller than numerical gradient delta used
170+ # NOTE this ignores the factor residual solve accuracy
171+ return tol* fgc. _h .< abs .(J)
172+ end
173+
174+ """
175+ $SIGNATURES
176+
177+ Return a tuple of infoPerCoord vectors that result from input `fromVar::Int`'s `infoPerCoord`.
178+ For example, a binary `LinearRelative` factor has a one to one influence from the input to the one other variable.
179+
180+ Notes
181+
182+ - Assumes the gradients in `fgc` are up to date -- if not, first run `fgc(measurement..., pts...)`.
183+ - `tol` does not recalculate the gradients to a new tolerance, instead uses the cached value in `fgc` to predict accuracy.
184+
185+ Example
186+
187+ ```julia
188+ # setup
189+ fct = LinearRelative(MvNormal([10;0],[1 0; 0 1]))
190+ measurement = ([10.0;0.0],)
191+ varTypes = (ContinuousEuclid{2}, ContinuousEuclid{2})
192+ pts = ([0;0.0], [9.5;0])
193+
194+ # create the gradients functor object
195+ fgc = FactorGradientsCached!(fct, varTypes, measurement, pts);
196+ # must first update the cached gradients
197+ fgc(measurement..., pts...)
198+
199+ # check the perturbation influence through gradients on factor
200+ ret = calcPerturbationFromVariable(fgc, 1, [1;1])
201+
202+ @assert isapprox(ret[2], [1;1])
203+ ```
204+
205+ DevNotes
206+ - FIXME Support n-ary source factors by extending `fromVar` to more than just one.
207+
208+ Related
209+
210+ [`FactorGradientsCached!`](@ref), [`checkGradientsToleranceMask`](@ref)
211+ """
212+ function calcPerturbationFromVariable (fgc:: FactorGradientsCached! ,
213+ fromVar:: Int ,
214+ infoPerCoord:: AbstractVector ;
215+ tol:: Real = 0.02 * fgc. _h )
216+ #
217+ blkszs = getCoordSizes (fgc)
218+ @assert blkszs[fromVar] == length (infoPerCoord) " Expecting incoming length(infoPerCoord) to equal the block size for variable $fromVar , as per factor used to construct the FactorGradientsCached!: $(getFactorType (fgc. dfgfct)) "
219+ # assume projection through pp-factor from first to second variable
220+ # ipc values from first variable belief, and zero for second
221+ ipcAll = zeros (sum (blkszs))
222+
223+ # nextVar = minimum([fromVar+1; length(blkszs)])
224+ curr_b = sum (blkszs[1 : (fromVar- 1 )]) + 1
225+ curr_e = sum (blkszs[1 : fromVar])
226+ ipcAll[curr_b: curr_e] .= infoPerCoord
227+
228+ # clamp gradients below numerical solver resolution
229+ mask = checkGradientsToleranceMask (fgc, tol= tol)
230+ J = fgc. cached_gradients
231+ _J = zeros (size (J)... )
232+ _J[mask] .= J[mask]
233+
234+ # calculate the gradient influence on other variables
235+ ipc_pert = _J * ipcAll
236+
237+ # round up over numerical solution tolerance
238+ dig = floor (Int, log10 (1 / tol))
239+ ipc_pert .= round .(ipc_pert, digits= dig)
240+
241+ # slice the result
242+ ipcBlk = []
243+ for (i, sz) in enumerate (blkszs)
244+ curr_b = sum (blkszs[1 : (i- 1 )]) + 1
245+ curr_e = sum (blkszs[1 : i])
246+ blk_ = view (ipc_pert, curr_b: curr_e)
247+ push! (ipcBlk, blk_)
248+ end
249+
250+ return tuple (ipcBlk... )
251+ end
252+
253+
254+ #
0 commit comments