Skip to content

Commit a24c6aa

Browse files
committed
Only lookup one parameter value
1 parent 72add18 commit a24c6aa

File tree

4 files changed

+71
-65
lines changed

4 files changed

+71
-65
lines changed

interface/CombineMathFuncs.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ inline int parametricHistFindBin(const int N_bins, const double* bins, const dou
317317
return -1;
318318
}
319319

320+
inline int parametricHistFindBin(const int N_bins, std::vector<double> const& bins, const double x)
321+
{
322+
return parametricHistFindBin(N_bins, bins.data(), x);
323+
}
324+
320325
inline Double_t parametricHistMorphScale(const double parVal, const int nMorphs,
321326
const double* morphCoeffs,
322327
const double* morphDiffs, const double* morphSums,
@@ -333,17 +338,12 @@ inline Double_t parametricHistMorphScale(const double parVal, const int nMorphs,
333338
return morphScale;
334339
}
335340

336-
inline Double_t parametricHistEvaluate(const double x, const double* parVals, const double* bins,
341+
inline Double_t parametricHistEvaluate(const int bin_i, const double* parVals, const double* bins,
337342
const int N_bins, const double* morphCoeffs, const int nMorphs,
338343
const double* morphDiffs, const double* morphSums,
339344
const double* widths, const double smoothRegion)
340345
{
341-
// Find which bin we're in first
342-
int bin_i = parametricHistFindBin(N_bins, bins, x);
343-
if (bin_i < 0) return 0.0; // Out of range
344-
345-
const double parVal = parVals[bin_i];
346-
346+
if (bin_i < 0) return 0.0;
347347
// Morphing case
348348
if (morphCoeffs != nullptr && nMorphs > 0) {
349349
// morphDiffs and morphSums are flattened arrays of size N_bins * nMorphs
@@ -355,6 +355,7 @@ inline Double_t parametricHistEvaluate(const double x, const double* parVals, co
355355
if (morphSums) {
356356
binMorphSums = morphSums + bin_i * nMorphs;
357357
}
358+
double parVal = parVals[bin_i];
358359
double scale = parametricHistMorphScale(parVal,
359360
nMorphs,
360361
morphCoeffs,
@@ -364,7 +365,7 @@ inline Double_t parametricHistEvaluate(const double x, const double* parVals, co
364365
return (parVal * scale) / widths[bin_i];
365366
}
366367
// No morphing case
367-
return parVal / widths[bin_i];
368+
return parVals[bin_i] / widths[bin_i];
368369
}
369370

370371
inline Double_t parametricMorphFunction(const int j, const double parVal, const bool hasMorphs, const int nMorphs,

interface/RooParametricHist.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ class RooParametricHist : public RooAbsPdf {
5757
double getParVal(int bin_i) const;
5858

5959
// Utility functions for data extraction
60-
std::vector<double> getParVals() const;
61-
std::vector<double> getCoeffs() const;
60+
const std::vector<double>& getParVals() const;
61+
const std::vector<double>& getCoeffs() const;
6262
void getFlattenedMorphs(std::vector<double>& diffs_flat, std::vector<double>& sums_flat) const;
6363

6464
protected:
@@ -78,6 +78,10 @@ class RooParametricHist : public RooAbsPdf {
7878
mutable std::vector<std::vector <double> > _diffs;
7979
mutable std::vector<std::vector <double> > _sums;
8080

81+
mutable std::vector<double> pars_vals_; //! Don't serialize me
82+
mutable std::vector<double> coeffs_; //! Don't serialize me
83+
mutable std::vector<double> diffs_flat_; //! Don't serialize me
84+
mutable std::vector<double> sums_flat_; //! Don't serialize me
8185

8286
void initializeBins(const TH1&) const;
8387
//void initializeNorm();

src/CombineCodegenImpl.cxx

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,24 @@ void RooFit::Experimental::codegenImpl(RooParametricHist& arg, CodegenContext& c
249249
if (arg.hasMorphs()) {
250250
arg.getFlattenedMorphs(diffs_flat, sums_flat);
251251
}
252-
ctx.addResult(&arg,
253-
ctx.buildCall("RooFit::Detail::MathFuncs::parametricHistEvaluate",
254-
arg.observable(),
255-
arg.getPars(),
256-
arg.getBins(),
257-
arg.getNBins(),
258-
arg.getCoeffList(),
259-
static_cast<int>(arg.getCoeffList().size()),
260-
arg.hasMorphs() ? diffs_flat : std::vector<double>{},
261-
arg.hasMorphs() ? sums_flat : std::vector<double>{},
262-
arg.getWidths(),
263-
arg.getSmoothRegion()));
252+
253+
std::stringstream bin_i;
254+
bin_i << ctx.buildCall("RooFit::Detail::MathFuncs::parametricHistFindBin", arg.getNBins(), arg.getBins(), arg.getX());
255+
std::stringstream code;
256+
code << ctx.buildCall("RooFit::Detail::MathFuncs::parametricHistEvaluate",
257+
bin_i.str(),
258+
arg.getPars(),
259+
arg.getBins(),
260+
arg.getNBins(),
261+
arg.getCoeffList(),
262+
static_cast<int>(arg.getCoeffList().size()),
263+
arg.hasMorphs() ? diffs_flat : std::vector<double>{},
264+
arg.hasMorphs() ? sums_flat : std::vector<double>{},
265+
arg.getWidths(),
266+
arg.getSmoothRegion())
267+
+ ";\n";
268+
ctx.addToCodeBody(code.str(), true);
269+
ctx.addResult(&arg, code.str());
264270
}
265271

266272
std::string RooFit::Experimental::codegenIntegralImpl(VerticalInterpPdf& arg,

src/RooParametricHist.cxx

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,15 @@ RooArgList & RooParametricHist::getAllBinVars() const {
9393
}
9494

9595
const double RooParametricHist::quickSum() const {
96-
std::vector<double> pars_vals = getParVals();
97-
std::vector<double> coeffs = getCoeffs();
98-
std::vector<double> diffs_flat;
99-
std::vector<double> sums_flat;
100-
getFlattenedMorphs(diffs_flat, sums_flat);
96+
getParVals();
97+
getCoeffs();
98+
getFlattenedMorphs(diffs_flat_, sums_flat_);
10199

102100
return RooFit::Detail::MathFuncs::parametricHistFullSum(
103-
pars_vals.data(), N_bins, _hasMorphs, _coeffList.getSize(),
104-
coeffs.data(),
105-
_hasMorphs ? diffs_flat.data() : nullptr,
106-
_hasMorphs ? sums_flat.data() : nullptr,
101+
pars_vals_.data(), N_bins, _hasMorphs, _coeffList.getSize(),
102+
coeffs_.data(),
103+
_hasMorphs ? diffs_flat_.data() : nullptr,
104+
_hasMorphs ? sums_flat_.data() : nullptr,
107105
_smoothRegion
108106
);
109107
}
@@ -118,19 +116,17 @@ Int_t RooParametricHist::getAnalyticalIntegral(RooArgSet& allVars, RooArgSet & a
118116
Double_t RooParametricHist::analyticalIntegral(Int_t code, const char* rangeName) const
119117
{
120118
assert(code==1) ;
121-
std::vector<double> pars_vals = getParVals();
122-
std::vector<double> coeffs = getCoeffs();
123-
std::vector<double> diffs_flat;
124-
std::vector<double> sums_flat;
125-
getFlattenedMorphs(diffs_flat, sums_flat);
119+
getParVals();
120+
getCoeffs();
121+
getFlattenedMorphs(diffs_flat_, sums_flat_);
126122
const double* bins_ptr = bins.data();
127123
const double* widths_ptr = widths.data();
128124
return RooFit::Detail::MathFuncs::parametricHistIntegral(
129-
pars_vals.data(), bins_ptr, N_bins,
130-
coeffs.data(),
125+
pars_vals_.data(), bins_ptr, N_bins,
126+
coeffs_.data(),
131127
_coeffList.getSize(),
132-
_hasMorphs ? diffs_flat.data() : nullptr,
133-
_hasMorphs ? sums_flat.data() : nullptr,
128+
_hasMorphs ? diffs_flat_.data() : nullptr,
129+
_hasMorphs ? sums_flat_.data() : nullptr,
134130
widths_ptr,
135131
_smoothRegion,
136132
rangeName,
@@ -165,67 +161,66 @@ void RooParametricHist::addMorphs(RooDataHist &hpdfU, RooDataHist &hpdfD, RooRea
165161

166162
Double_t RooParametricHist::evaluate() const
167163
{
168-
// Find which bin we're in first
164+
// Find which bin we are in and lookup the parameter value
169165
double xVal = getX();
170-
std::vector<double> pars_vals = getParVals();
171-
if (pars_vals.empty() || bins.empty() || widths.empty()) return 0.0;
166+
int bin_i = RooFit::Detail::MathFuncs::parametricHistFindBin(N_bins, bins, xVal);
167+
if (bin_i < 0) return 0.0; // Out of range
168+
pars_vals_[bin_i] = getParVal(bin_i);
169+
if (bins.empty() || widths.empty()) return 0.0;
172170
const double* bins_ptr = bins.data();
173171
const double* widths_ptr = widths.data();
174172
int nMorphs = _coeffList.getSize();
175-
std::vector<double> diffs_flat;
176-
std::vector<double> sums_flat;
177-
getFlattenedMorphs(diffs_flat, sums_flat);
173+
getFlattenedMorphs(diffs_flat_, sums_flat_);
178174

179-
std::vector<double> coeffs;
180175
if (_hasMorphs) {
181-
coeffs = getCoeffs();
176+
getCoeffs();
182177
}
183178

184179
return RooFit::Detail::MathFuncs::parametricHistEvaluate(
185-
xVal,
186-
pars_vals.data(),
180+
bin_i,
181+
pars_vals_.data(),
187182
bins_ptr,
188183
N_bins,
189-
_hasMorphs ? coeffs.data() : nullptr,
184+
_hasMorphs ? coeffs_.data() : nullptr,
190185
nMorphs,
191-
_hasMorphs ? diffs_flat.data() : nullptr,
192-
_hasMorphs ? sums_flat.data() : nullptr,
186+
_hasMorphs ? diffs_flat_.data() : nullptr,
187+
_hasMorphs ? sums_flat_.data() : nullptr,
193188
widths_ptr,
194189
_smoothRegion
195190
);
196-
197191
}
198192

199193
double RooParametricHist::getParVal(int bin_i) const {
200194
return static_cast<RooAbsReal*>(pars.at(bin_i))->getVal();
201195
}
202196

203-
std::vector<double> RooParametricHist::getParVals() const {
204-
std::vector<double> pars_vals;
205-
pars_vals.reserve(pars.getSize());
197+
const std::vector<double>& RooParametricHist::getParVals() const {
198+
pars_vals_.clear();
199+
pars_vals_.reserve(pars.getSize());
206200
for (int i = 0; i < pars.getSize(); ++i) {
207-
pars_vals.push_back(static_cast<RooAbsReal*>(pars.at(i))->getVal());
201+
pars_vals_.push_back(static_cast<RooAbsReal*>(pars.at(i))->getVal());
208202
}
209-
return pars_vals;
203+
return pars_vals_;
210204
}
211205

212-
std::vector<double> RooParametricHist::getCoeffs() const {
213-
std::vector<double> coeffs;
214-
coeffs.reserve(_coeffList.getSize());
206+
const std::vector<double>& RooParametricHist::getCoeffs() const {
207+
coeffs_.clear();
208+
coeffs_.reserve(_coeffList.getSize());
215209
for (int i = 0; i < _coeffList.getSize(); ++i) {
216-
coeffs.push_back(static_cast<RooRealVar*>(_coeffList.at(i))->getVal());
210+
coeffs_.push_back(static_cast<RooRealVar*>(_coeffList.at(i))->getVal());
217211
}
218-
return coeffs;
212+
return coeffs_;
219213
}
220214

221215
void RooParametricHist::getFlattenedMorphs(std::vector<double>& diffs_flat, std::vector<double>& sums_flat) const {
222216
if (!_hasMorphs) return;
217+
// _diffs/_sums are immutable after construction, so only flatten once
218+
if (!diffs_flat.empty()) return;
223219

224220
int nMorphs = _coeffList.getSize();
225221
diffs_flat.reserve(N_bins * nMorphs);
226222
sums_flat.reserve(N_bins * nMorphs);
227223

228-
// Morphs are indexed as [bin][morph], need to flatten to [morph * N_bins + bin]
229224
for (int i = 0; i < N_bins; ++i) {
230225
for (int j = 0; j < nMorphs; ++j) {
231226
diffs_flat.push_back(_diffs[i][j]);

0 commit comments

Comments
 (0)