Skip to content

Commit 819defa

Browse files
HannaOlvhammarguitargeek
authored andcommitted
[RF] Vectorize weight evaluation with RooDataHist::weights()
This PR improves the speed for evaluating weights in `RooHistPdf` and `RooHistFunc` for one dimensional histograms with no interpolation. In the future, `RooDataHist::weights()` can be extended to cover cases with higher dimensions and interpolation. The function `RooDataHist::weights()` was implemented to enable vectorized evaluations of bin weights. In `RooHistPdf` it is implemented using the new function `RooHistPdf::computeBatch()`, which calls `RooDataHist::weights()` in the case of no interpolation and 1D histograms, and `RooAbsReal::computeBatch()` otherwise. In `RooHistFunc::computeBatch`, `RooDataHist::weights()` is called in the case of no interpolation and 1D histograms and is unchanged in the other cases. To calculate the weight, bin indices are stored as a vector using `RooAbsBinning::binNumbers`, which was implemented in root-project#11151.
1 parent 77553d9 commit 819defa

File tree

5 files changed

+46
-1
lines changed

5 files changed

+46
-1
lines changed

roofit/roofitcore/inc/RooDataHist.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class RooDataHist : public RooAbsData, public RooDirItem {
106106
const std::map<const RooAbsArg*, std::pair<double, double> >& ranges,
107107
std::function<double(int)> getBinScale = [](int){ return 1.0; } );
108108

109+
void weights(double* output, RooSpan<double const> xVals, bool correctForBinSize);
109110
/// Return weight of i-th bin. \see getIndex()
110111
double weight(std::size_t i) const { return _wgt[i]; }
111112
double weightFast(const RooArgSet& bin, int intOrder, bool correctForBinSize, bool cdfBoundaries);

roofit/roofitcore/inc/RooHistPdf.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class RooHistPdf : public RooAbsPdf {
102102
std::list<double>* binBoundaries(RooAbsRealLValue& /*obs*/, double /*xlo*/, double /*xhi*/) const override ;
103103
bool isBinnedDistribution(const RooArgSet&) const override { return _intOrder==0 ; }
104104

105+
void computeBatch(cudaStream_t*, double* output, size_t size, RooFit::Detail::DataMap const&) const override;
106+
105107

106108
protected:
107109

roofit/roofitcore/src/RooDataHist.cxx

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,33 @@ RooPlot *RooDataHist::plotOn(RooPlot *frame, PlotOpt o) const
11421142
return RooAbsData::plotOn(frame,o) ;
11431143
}
11441144

1145+
1146+
////////////////////////////////////////////////////////////////////////////////
1147+
/// A vectorized version of RooDataHist::weight for one dimensional histograms
1148+
/// with no interpolation.
1149+
/// \param[out] output An array of weights corresponding the values in xVals.
1150+
/// \param[in] xVals An array of coordinates for which the weights should be
1151+
/// calculated.
1152+
/// \param[in] correctForBinSize Enable the inverse bin volume correction factor.
1153+
1154+
void RooDataHist::weights(double* output, RooSpan<double const> xVals, bool correctForBinSize)
1155+
{
1156+
auto const nEvents = xVals.size();
1157+
RooAbsBinning const& binning = *_lvbins[0];
1158+
1159+
// Reuse the output buffer for bin indices and zero-initialize it
1160+
auto binIndices = reinterpret_cast<int*>(output + nEvents) - nEvents;
1161+
std::fill(binIndices, binIndices + nEvents, 0);
1162+
1163+
binning.binNumbers(xVals.data(), binIndices, nEvents);
1164+
1165+
for (std::size_t i=0; i < nEvents; ++i) {
1166+
auto binIdx = binIndices[i];
1167+
output[i] = correctForBinSize ? _wgt[binIdx] / _binv[binIdx] : _wgt[binIdx];
1168+
}
1169+
}
1170+
1171+
11451172
////////////////////////////////////////////////////////////////////////////////
11461173
/// A faster version of RooDataHist::weight that assumes the passed arguments
11471174
/// are aligned with the histogram variables.
@@ -1243,7 +1270,6 @@ double RooDataHist::weightInterpolated(const RooArgSet& bin, int intOrder, bool
12431270

12441271
double wInt{0} ;
12451272
if (varInfo.nRealVars == 1) {
1246-
12471273
// buffer needs to be 2 x (interpolation order + 1), with the factor 2 for x and y.
12481274
_interpolationBuffer.resize(2 * intOrder + 2);
12491275

roofit/roofitcore/src/RooHistFunc.cxx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ double RooHistFunc::evaluate() const
196196

197197

198198
void RooHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooFit::Detail::DataMap const& dataMap) const {
199+
if (_depList.size() == 1 && _intOrder == 0) {
200+
auto xVals = dataMap.at(_depList[0]);
201+
_dataHist->weights(output, xVals, false);
202+
return;
203+
}
204+
199205
std::vector<RooSpan<const double>> inputValues;
200206
for (const auto& obs : _depList) {
201207
auto realObs = dynamic_cast<const RooAbsReal*>(obs);

roofit/roofitcore/src/RooHistPdf.cxx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,17 @@ RooHistPdf::~RooHistPdf()
188188
}
189189

190190

191+
void RooHistPdf::computeBatch(cudaStream_t*, double* output, size_t nEvents, RooFit::Detail::DataMap const& dataMap) const {
191192

193+
// For interpolation and histograms of higher dimension, use base function
194+
if(_pdfObsList.size() > 1 || _intOrder > 0) {
195+
RooAbsReal::computeBatch(nullptr, output, nEvents, dataMap);
196+
return;
197+
}
198+
199+
auto xVals = dataMap.at(_pdfObsList[0]);
200+
_dataHist->weights(output, xVals, true);
201+
}
192202

193203

194204
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)