Skip to content

Commit 037c0c4

Browse files
committed
attemptGradients on true only
1 parent 7068b44 commit 037c0c4

File tree

3 files changed

+24
-17
lines changed

3 files changed

+24
-17
lines changed

src/Serialization/services/DispatchPackedConversions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function reconstFactorData(
5050
certainhypo = packed.certainhypo,
5151
inflation = packed.inflation,
5252
userCache,
53+
attemptGradients = getSolverParams(dfg).attemptGradients,
5354
# Block recursion if NoSolverParams or if set to not attempt gradients.
5455
_blockRecursion=
5556
getSolverParams(dfg) isa NoSolverParams ||

src/services/CalcFactor.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ function _prepCCW(
366366
inflation::Real = 0.0,
367367
solveKey::Symbol = :default,
368368
_blockRecursion::Bool = false,
369+
attemptGradients::Bool = true,
369370
userCache::CT = nothing,
370371
) where {T <: AbstractFactor, CT}
371372
#
@@ -416,14 +417,18 @@ function _prepCCW(
416417
varTypes = getVariableType.(fullvariables)
417418

418419
# as per struct CommonConvWrapper
419-
gradients = attemptGradientPrep(
420-
varTypes,
421-
usrfnc,
422-
_varValsAll,
423-
multihypo,
424-
meas_single,
425-
_blockRecursion,
426-
)
420+
_gradients = if attemptGradients
421+
attemptGradientPrep(
422+
varTypes,
423+
usrfnc,
424+
_varValsAll,
425+
multihypo,
426+
meas_single,
427+
_blockRecursion,
428+
)
429+
else
430+
nothing
431+
end
427432

428433
# variable Types
429434
pttypes = getVariableType.(Xi) .|> getPointType
@@ -432,21 +437,21 @@ function _prepCCW(
432437
@warn "_prepCCW PointType is not concrete $PointType" maxlog=50
433438
end
434439

435-
return CommonConvWrapper(
436-
usrfnc,
440+
# PointType[],
441+
return CommonConvWrapper(;
442+
usrfnc! = usrfnc,
437443
fullvariables,
438-
_varValsAll,
439-
PointType[];
440-
userCache, # should be higher in args list
441-
manifold, # should be higher in args list
444+
varValsAll = _varValsAll,
445+
dummyCache = userCache,
446+
manifold,
442447
partialDims,
443448
partial,
444-
nullhypo,
445-
inflation,
449+
nullhypo = float(nullhypo),
450+
inflation = float(inflation),
446451
hypotheses = multihypo,
447452
certainhypo,
448453
measurement,
449-
gradients,
454+
_gradients,
450455
)
451456
end
452457

src/services/FactorGraph.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ function getDefaultFactorData(
733733
multihypo = mhcat,
734734
nullhypo = nh,
735735
inflation,
736+
attemptGradients = getSolverParams(dfg).attemptGradients,
736737
_blockRecursion,
737738
userCache,
738739
)

0 commit comments

Comments
 (0)