Skip to content

Commit e962582

Browse files
committed
ccw.varValsAll is Ref
1 parent bb851a7 commit e962582

9 files changed

+82
-36
lines changed

src/Deprecated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ function Base.getproperty(ccw::CommonConvWrapper, f::Symbol)
257257
# return SingleThreaded
258258
elseif f == :params
259259
error("CommonConvWrapper.params is deprecated, use .varValsAll instead")
260-
return ccw.varValsAll
260+
return ccw.varValsAll[]
261261
elseif f == :vartypes
262262
@warn "CommonConvWrapper.vartypes is deprecated, use typeof.(getVariableType.(ccw.fullvariables) instead" maxlog=3
263263
return typeof.(getVariableType.(ccw.fullvariables))

src/Factors/LinearRelative.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ getDimension(::InstanceType{LinearRelative{N}}) where {N} = N
4141
# new and simplified interface for both nonparametric and parametric
4242
function (s::CalcFactor{<:LinearRelative})(z, x1, x2)
4343
# TODO convert to distance(distance(x2,x1),z) # or use dispatch on `-` -- what to do about `.-`
44+
# if s._sampleIdx < 5
45+
# @info "LinearRelative" s._sampleIdx "$z" "$x1" "$x2" s.solvefor getLabel.(s.fullvariables)
46+
# @info "in variables" pointer(getVal(s.fullvariables[s.solvefor])) getVal(s.fullvariables[s.solvefor])[1]
47+
# end
4448
return z .- (x2 .- x1)
4549
end
4650

src/entities/FactorOperationalMemory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Related
8585
Base.@kwdef struct CommonConvWrapper{
8686
T <: AbstractFactor,
8787
VT <: Tuple,
88-
TP <: Tuple,
88+
TP <: Base.RefValue{<:Tuple},
8989
CT,
9090
AM <: AbstractManifold,
9191
HR <: HypoRecipeCompute,

src/services/CalcFactor.jl

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function CalcFactor(
2626
ccwl::CommonConvWrapper;
2727
factor = ccwl.usrfnc!,
2828
_sampleIdx = 0,
29-
_legacyParams = ccwl.varValsAll,
29+
_legacyParams = ccwl.varValsAll[],
3030
_allowThreads = true,
3131
cache = ccwl.dummyCache,
3232
fullvariables = ccwl.fullvariables,
@@ -211,6 +211,14 @@ function _resizePointsVector!(
211211
return vecP
212212
end
213213

214+
function _checkVarValPointers(dfg::AbstractDFG, fclb::Symbol)
215+
vars = getVariable.(dfg, getVariableOrder(dfg,fclb))
216+
ptrsV = pointer.(getVal.(vars))
217+
ccw = _getCCW(dfg, fclb)
218+
ptrsC = pointer.(ccw.varValsAll[])
219+
ptrsV, ptrsC
220+
end
221+
214222
"""
215223
$(SIGNATURES)
216224
@@ -228,28 +236,33 @@ Notes
228236
- `P = getPointType(<:InferenceVariable)`
229237
"""
230238
function _createVarValsAll(
231-
Xi::Vector{<:DFGVariable};
239+
variables::AbstractVector{<:DFGVariable};
232240
solveKey::Symbol = :default,
233241
)
234242
#
235243
# Note, NamedTuple once upon a time created way too much recompile load on repeat solves, #1564
236-
varValsAll = map(var_i->getVal(var_i; solveKey), tuple(Xi...))
244+
varValsAll = map(var_i->getVal(var_i; solveKey), tuple(variables...))
245+
246+
for (i,vv) in enumerate(varValsAll)
247+
@assert pointer(vv) == pointer(getVal(variables[i]; solveKey)) "Developer check that ccw.varValsAll pointers go to same memory as getVal(variable)"
248+
end
237249

238250
# how many points
239251
LEN = length.(varValsAll)
240252
maxlen = maximum(LEN)
241253
# NOTE, forcing maxlen to N results in errors (see test/testVariousNSolveSize.jl) see #105
242254
# maxlen = N == 0 ? maxlen : N
243255

244-
# allow each variable to have a different number of points, which is resized during compute here
245-
# resample variables with too few kernels (manifolds points)
246-
SAMP = LEN .< maxlen
247-
for i = 1:length(Xi)
248-
if SAMP[i]
249-
Pr = getBelief(Xi[i], solveKey)
250-
_resizePointsVector!(varValsAll[i], Pr, maxlen)
251-
end
252-
end
256+
# NOTE resize! moves the pointer!!!!!!
257+
# # allow each variable to have a different number of points, which is resized during compute here
258+
# # resample variables with too few kernels (manifolds points)
259+
# SAMP = LEN .< maxlen
260+
# for i = 1:length(variables)
261+
# if SAMP[i]
262+
# Pr = getBelief(variables[i], solveKey)
263+
# _resizePointsVector!(varValsAll[i], Pr, maxlen)
264+
# end
265+
# end
253266

254267
# TODO --rather define reusable memory for the proposal
255268
# we are generating a proposal distribution, not direct replacement for existing memory and hence the deepcopy.
@@ -429,7 +442,7 @@ function _createCCW(
429442
return CommonConvWrapper(;
430443
usrfnc! = usrfnc,
431444
fullvariables,
432-
varValsAll = _varValsAll,
445+
varValsAll = Ref(_varValsAll),
433446
dummyCache = userCache,
434447
manifold,
435448
partialDims,
@@ -498,8 +511,9 @@ function _beforeSolveCCW!(
498511
# set the 'solvefor' variable index -- i.e. which connected variable of the factor is being computed in this convolution.
499512
# ccwl.varidx[] = findfirst(==(solvefor), getLabel.(variables))
500513
# everybody use maxlen number of points in belief function estimation
501-
maxlen = maximum((N, length.(ccwl.varValsAll)...,))
514+
maxlen = maximum((N, length.(ccwl.varValsAll[])...,))
502515

516+
ccwl.varValsAll[] = map(s->getVal(s; solveKey), tuple(variables...))
503517
## PLAN B, make deep copy of ccwl.varValsAll[ccwl.varidx[]] just before the numerical solve
504518

505519
# maxlen, ccwl.varidx[] = _updateParamVec(variables, solvefor, ccwl.varValsAll, N; solveKey)
@@ -558,7 +572,7 @@ function _beforeSolveCCW!(
558572
_setCCWDecisionDimsConv!(ccwl, getDimension(getVariableType(Xi[ccwl.varidx[]])))
559573

560574
solveForPts = getVal(Xi[ccwl.varidx[]]; solveKey)
561-
maxlen = maximum([N; length(solveForPts); length(ccwl.varValsAll[ccwl.varidx[]])]) # calcZDim(ccwl); length(measurement[1])
575+
maxlen = maximum([N; length(solveForPts); length(ccwl.varValsAll[][ccwl.varidx[]])]) # calcZDim(ccwl); length(measurement[1])
562576

563577
# FIXME do not divert Mixture for sampling
564578
# update ccwl.measurement values

src/services/EvalFactor.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function calcVariableDistanceExpectedFractional(
4242
sfidx::Integer,
4343
certainidx::AbstractVector{<:Integer};
4444
kappa::Real = 3.0,
45-
readonlyVarVals = ccwl.varValsAll[sfidx]
45+
readonlyVarVals = ccwl.varValsAll[][sfidx]
4646
)
4747
#
4848
varTypes = getVariableType.(ccwl.fullvariables)
@@ -54,14 +54,14 @@ function calcVariableDistanceExpectedFractional(
5454

5555
# get mean of all fractional variables
5656
# ccwl.params::Vector{Vector{P}}
57-
uncertainidx = setdiff(1:length(ccwl.varValsAll), certainidx)
57+
uncertainidx = setdiff(1:length(ccwl.varValsAll[]), certainidx)
5858
dists = zeros(length(uncertainidx) + length(certainidx))
5959

6060
dims = manifold_dimension(getManifold(varTypes[sfidx]))
6161

6262
uncMeans = zeros(dims, length(uncertainidx))
6363
for (count, i) in enumerate(uncertainidx)
64-
u = mean(getManifold(varTypes[i]), ccwl.varValsAll[i])
64+
u = mean(getManifold(varTypes[i]), ccwl.varValsAll[][i])
6565
uncMeans[:, count] .= getCoordinates(varTypes[i], u)
6666
end
6767
count = 0
@@ -79,7 +79,7 @@ function calcVariableDistanceExpectedFractional(
7979
# also check distance to certainidx for general scale reference (workaround heuristic)
8080
for cidx in certainidx
8181
count += 1
82-
cerMeanPnt = mean(getManifold(varTypes[cidx]), ccwl.varValsAll[cidx])
82+
cerMeanPnt = mean(getManifold(varTypes[cidx]), ccwl.varValsAll[][cidx])
8383
cerMean = getCoordinates(varTypes[cidx], cerMeanPnt)
8484
dists[count] = norm(refMean[1:dims] - cerMean[1:dims])
8585
end
@@ -145,7 +145,7 @@ function computeAcrossHypothesis!(
145145
sfidx::Int,
146146
maxlen::Int,
147147
mani::ManifoldsBase.AbstractManifold; # maniAddOps::Tuple;
148-
destinationVarVals = deepcopy(ccwl.varValsAll[sfidx]),
148+
destinationVarVals = deepcopy(ccwl.varValsAll[][sfidx]),
149149
spreadNH::Real = 5.0,
150150
inflateCycles::Int = 3,
151151
skipSolve::Bool = false,
@@ -263,7 +263,7 @@ function _calcIPCRelative(
263263
sfidx_active = sum(active_mask[1:sfidx])
264264

265265
# build a view to the decision variable memory
266-
activeParams = view(ccwl.varValsAll, activeids)
266+
activeParams = view(ccwl.varValsAll[], activeids)
267267
activeVars = Xi[active_mask]
268268

269269
# assume gradients are just done for the first sample values
@@ -320,8 +320,8 @@ function evalPotentialSpecific(
320320
needFreshMeasurements::Bool = true, # superceeds over measurement
321321
solveKey::Symbol = :default,
322322
sfidx::Integer=findfirst(==(solvefor), getLabel.(variables)),
323-
destinationVarVals = deepcopy(ccwl.varValsAll[sfidx]),
324-
N::Int = 0 < length(measurement) ? length(measurement) : maximum(Npts.(getBelief.(Xi, solveKey))),
323+
destinationVarVals = deepcopy(ccwl.varValsAll[][sfidx]),
324+
N::Int = 0 < length(measurement) ? length(measurement) : maximum(Npts.(getBelief.(variables, solveKey))),
325325
spreadNH::Real = 3.0,
326326
inflateCycles::Int = 3,
327327
nullSurplus::Real = 0,
@@ -353,7 +353,7 @@ function evalPotentialSpecific(
353353
# addOps, d1, d2, d3 = buildHybridManifoldCallbacks(manis)
354354
mani = getManifold(variables[sfidx])
355355

356-
@assert destinationVarVals !== ccwl.varValsAll[ccwl.varidx[]] "destination of evalPotential for AbstractRelative not be ccwl.varValsAll[sfidx]"
356+
@assert destinationVarVals !== ccwl.varValsAll[][ccwl.varidx[]] "destination of evalPotential for AbstractRelative not be ccwl.varValsAll[sfidx]"
357357
@assert destinationVarVals !== getVal(variables[ccwl.varidx[]]) "destination of evalPotential for AbstractRelative not be variable.VND.val"
358358

359359
# perform the numeric solutions on the indicated elements
@@ -390,6 +390,7 @@ function evalPotentialSpecific(
390390
# return ccwl.varValsAll[sfidx], ipc
391391
end
392392

393+
393394
# TODO `measurement` might not be properly wired up yet
394395
# TODO consider 1051 here to inflate proposals as general behaviour
395396
function evalPotentialSpecific(

src/services/NumericalCalculations.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,9 @@ function _buildCalcFactorLambdaSample(
349349
# DevNotes, also see new `hyporecipe` approach (towards consolidation CCW CPT FMd CF...)
350350

351351
# build a view to the decision variable memory
352-
varValsHypo = ccwl.varValsAll[ccwl.hyporecipe.activehypo]
352+
varValsHypo = ccwl.varValsAll[][ccwl.hyporecipe.activehypo]
353353
# tup = tuple(varParams...)
354-
# nms = keys(ccwl.varValsAll)[cpt_.activehypo]
354+
# nms = keys(ccwl.varValsAll[])[cpt_.activehypo]
355355
# varValsHypo = NamedTuple{nms,typeof(tup)}(tup)
356356

357357
# prepare fmd according to hypo selection
@@ -474,7 +474,6 @@ end
474474
# should only be calling a new arg list according to activehypo at start of particle
475475
# Try calling an existing lambda
476476
# sensitive to which hypo of course , see #1024
477-
# need to shuffle content inside .cpt.fmd as well as .varValsAll accordingly
478477
#
479478

480479
function _solveCCWNumeric!(

test/testCommonConvWrapper.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ pts = approxConv(fg, getFactor(fg, :x0x1f1), :x1)
120120

121121
ccw = IIF._getCCW(fg, :x0x1f1)
122122

123-
ptr_ = ccw.varValsAll[ccw.varidx[]]
123+
ptr_ = ccw.varValsAll[][ccw.varidx[]]
124124
@cast tp1[i,j] := ptr_[j][i]
125125
@test 90.0 < Statistics.mean(tp1) < 110.0
126-
ptr_ = ccw.varValsAll[1]
126+
ptr_ = ccw.varValsAll[][1]
127127
@cast tp2[i,j] := ptr_[j][i]
128128
@test -10.0 < Statistics.mean(tp2) < 10.0
129129

@@ -135,10 +135,10 @@ initVariable!(fg, :x1, [100*ones(1) for _ in 1:100])
135135

136136
pts = approxConv(fg, getFactor(fg, :x0x1f1), :x0)
137137

138-
ptr_ = ccw.varValsAll[1]
138+
ptr_ = ccw.varValsAll[][1]
139139
@cast tp1[i,j] := ptr_[j][i]
140140
@test -10.0 < Statistics.mean(tp1) < 10.0
141-
ptr_ = ccw.varValsAll[2]
141+
ptr_ = ccw.varValsAll[][2]
142142
@cast tp2[i,j] := ptr_[j][i]
143143
@test 90.0 < Statistics.mean(tp2) < 110.0
144144

test/testMixtureLinearConditional.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ f1_ = DFG.unpackFactor(fg_, pf1)
111111
@show typeof(f1)
112112
@show typeof(f1_)
113113

114-
@show typeof(getSolverData(f1).fnc.varValsAll);
115-
@show typeof(getSolverData(f1_).fnc.varValsAll);
114+
@show typeof(getSolverData(f1).fnc.varValsAll[]);
115+
@show typeof(getSolverData(f1_).fnc.varValsAll[]);
116116

117117
@test DFG.compareFactor(f1, f1_, skip=[:components;:labels;:timezone;:zone;:vartypes;:fullvariables;:particleidx;:varidx])
118118

test/testMultiHypo3Door.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ x3 = 40.0
3333
fg = initfg()
3434
getSolverParams(fg).N = n_samples
3535
getSolverParams(fg).gibbsIters = 5
36+
# forcefully work with ccw.varValsAll to check the pointers are pointing to getVal.(variables)
37+
getSolverParams(fg).graphinit = false
3638

3739
# Place strong prior on locations of three "doors"
3840
addVariable!(fg, :l0, ContinuousScalar, N=n_samples)
@@ -54,10 +56,36 @@ addVariable!(fg, :x0, ContinuousScalar, N=n_samples)
5456
# Make first "door" measurement
5557
f1 = addFactor!(fg, [:x0; :l0; :l1; :l2; :l3], LinearRelative(Normal(0, meas_noise)), multihypo=[1.0; (1/4 for _=1:4)...])
5658

59+
# check pointers before init
60+
a,b = IIF._checkVarValPointers(fg, getLabel(f1))
61+
for i in 1:length(a)
62+
@test a[i] == b[i]
63+
end
64+
65+
# do init (and check that the var pointers did not change)
66+
doautoinit!(fg ,:l0)
67+
doautoinit!(fg ,:l1)
68+
doautoinit!(fg ,:l2)
69+
doautoinit!(fg ,:l3)
70+
71+
a_,b_ = IIF._checkVarValPointers(fg, getLabel(f1))
72+
for i in 1:length(a)
73+
@test a_[i] == b_[i]
74+
end
75+
for i in 1:length(a)
76+
@test a[i] == a_[i]
77+
end
78+
for i in 1:length(a)
79+
@test b[i] == b_[i]
80+
end
81+
82+
5783
# make sure approxConv is as expected
5884
@test isInitialized.(fg, [:l0;:l1;:l2;:l3]) |> all
5985
X0 = approxConvBelief(fg, getLabel(f1), :x0)
60-
# smpls = sampleFactor(fg, f1.label,200)
86+
# smpls = sampleFactor(fg, f1.label,10)
87+
88+
##
6189

6290
# should have four equal sized peaks at landmark locations
6391
@test 0.1 < X0([l0])[1]

0 commit comments

Comments
 (0)