Skip to content

Commit 565e9aa

Browse files
committed
[RF] Don't deep clone the RooAbsArg in RooAbsReal::getPropagatedError
The `RooAbsReal::getPropagatedError()` function was using some of the most expensive operations in RooFit for larger computation graphs: cloning the model, and figuring out parameters and observables. This was done for no apparent reason, as the `RooAbsReal` is not mutated by `getPropagaterError`. Parameter values are slightly changed for reevaluation, but they are reset right after. A final call to `getVal()` is enough to reset the original state, which is much more efficient than cloning everything. This commit also adds a check if the parameters in the RooAbsReal have the same values as in the fit result (otherwise the logic of getPropagatedError() is broken). This change was motivated by the following forum post: https://root-forum.cern.ch/t/getpropagatederror-method-taking-too-long-to-run/50392
1 parent c3beaf9 commit 565e9aa

File tree

1 file changed

+39
-42
lines changed

1 file changed

+39
-42
lines changed

roofit/roofitcore/src/RooAbsReal.cxx

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,68 +2658,65 @@ RooPlot* RooAbsReal::plotAsymOn(RooPlot *frame, const RooAbsCategoryLValue& asym
26582658
/// \f$ \mathrm{Cov}(mathbf{a},mathbf{a}') \f$ = the covariance matrix from the fit result.
26592659
///
26602660

2661-
double RooAbsReal::getPropagatedError(const RooFitResult &fr, const RooArgSet &nset_in) const
2661+
double RooAbsReal::getPropagatedError(const RooFitResult &fr, const RooArgSet &nset) const
26622662
{
2663+
// Calling getParameters() might be costly, but necessary to get the right
2664+
// parameters in the RooAbsReal. The RooFitResult only stores snapshots.
2665+
RooArgSet allParamsInAbsReal;
2666+
getParameters(&nset, allParamsInAbsReal);
2667+
2668+
// Strip out parameters with zero error
2669+
RooArgList paramList;
2670+
for(auto * rrvInFitResult : static_range_cast<RooRealVar*>(fr.floatParsFinal())) {
2671+
if (rrvInFitResult->getError() > 1e-20) {
2672+
auto * rrvInAbsReal = static_cast<RooRealVar*>(allParamsInAbsReal.find(*rrvInFitResult));
2673+
2674+
if(rrvInAbsReal->getVal() != rrvInFitResult->getVal()) {
2675+
throw std::runtime_error(
2676+
std::string("RooAbsReal::getPropagatedError(): the parameters of the RooAbsReal don't have") +
2677+
"the same values as in the fit result! The logic of getPropagatedError is broken in this case.");
2678+
}
26632679

2664-
// Strip out parameters with zero error
2665-
RooArgList fpf_stripped;
2666-
RooFIter fi = fr.floatParsFinal().fwdIterator();
2667-
RooRealVar *frv;
2668-
while ((frv = (RooRealVar *)fi.next())) {
2669-
if (frv->getError() > 1e-20) {
2670-
fpf_stripped.add(*frv);
2671-
}
2672-
}
2673-
2674-
// Clone self for internal use
2675-
std::unique_ptr<RooAbsReal> cloneFunc{static_cast<RooAbsReal*>(cloneTree())};
2676-
RooArgSet errorParams;
2677-
cloneFunc->getObservables(&fpf_stripped, errorParams);
2678-
2679-
RooArgSet nset;
2680-
if (nset_in.empty()) {
2681-
cloneFunc->getParameters(&errorParams, nset);
2682-
} else {
2683-
cloneFunc->getObservables(&nset_in, nset);
2684-
}
2685-
2686-
// Make list of parameter instances of cloneFunc in order of error matrix
2687-
RooArgList paramList;
2688-
const RooArgList &fpf = fpf_stripped;
2689-
vector<int> fpf_idx;
2690-
for (Int_t i = 0; i < fpf.getSize(); i++) {
2691-
RooAbsArg *par = errorParams.find(fpf[i].GetName());
2692-
if (par) {
2693-
paramList.add(*par);
2694-
fpf_idx.push_back(i);
2695-
}
2680+
paramList.add(*rrvInAbsReal);
2681+
}
26962682
}
26972683

2698-
vector<double> plusVar, minusVar ;
2684+
std::vector<double> plusVar, minusVar ;
2685+
plusVar.reserve(paramList.size());
2686+
minusVar.reserve(paramList.size());
26992687

27002688
// Create vector of plus,minus variations for each parameter
2701-
TMatrixDSym V(paramList.getSize()==fr.floatParsFinal().getSize()?
2702-
fr.covarianceMatrix():
2689+
TMatrixDSym V(paramList.size() == fr.floatParsFinal().size() ?
2690+
fr.covarianceMatrix() :
27032691
fr.reducedCovarianceMatrix(paramList)) ;
27042692

27052693
for (Int_t ivar=0 ; ivar<paramList.getSize() ; ivar++) {
27062694

2707-
RooRealVar& rrv = (RooRealVar&)fpf[fpf_idx[ivar]] ;
2695+
auto& rrv = static_cast<RooRealVar&>(paramList[ivar]);
27082696

27092697
double cenVal = rrv.getVal() ;
27102698
double errVal = sqrt(V(ivar,ivar)) ;
27112699

27122700
// Make Plus variation
2713-
((RooRealVar*)paramList.at(ivar))->setVal(cenVal+errVal) ;
2714-
plusVar.push_back(cloneFunc->getVal(nset)) ;
2701+
rrv.setVal(cenVal+errVal) ;
2702+
plusVar.push_back(getVal(nset)) ;
27152703

27162704
// Make Minus variation
2717-
((RooRealVar*)paramList.at(ivar))->setVal(cenVal-errVal) ;
2718-
minusVar.push_back(cloneFunc->getVal(nset)) ;
2705+
rrv.setVal(cenVal-errVal) ;
2706+
minusVar.push_back(getVal(nset)) ;
27192707

2720-
((RooRealVar*)paramList.at(ivar))->setVal(cenVal) ;
2708+
rrv.setVal(cenVal) ;
27212709
}
27222710

2711+
// Re-evaluate this RooAbsReal with the central parameters just to be
2712+
// extra-safe that a call to `getPropagatedError()` doesn't change any state.
2713+
// It should not be necessarry because thanks to the dirty flag propagation
2714+
// the RooAbsReal is re-evaluated anyway the next time getVal() is called.
2715+
// Still there are imaginable corner cases where it would not be triggered,
2716+
// for example if the user changes the RooFit operation more after the error
2717+
// propagation.
2718+
getVal(nset);
2719+
27232720
TMatrixDSym C(paramList.getSize()) ;
27242721
vector<double> errVec(paramList.getSize()) ;
27252722
for (int i=0 ; i<paramList.getSize() ; i++) {

0 commit comments

Comments
 (0)