Skip to content

Commit 0de46d4

Browse files
committed
[RF] Better ParamHistFunc::computeBatch with RooAbsBinning::binNumbers
So far, the `ParamHistFunc` BatchMode implementation was still sub-uptimal because it used the non-vectorized interface of the RooDataHist. Using the new `RooAbsBinning::binNumbers()` function too look up multiple bin indices at once, the implementaiton can be improved and sped. In the three-dimensional many-bin case, the new implementation is a bit more than three times faster than the old one. This should benefit HistFactory fits with many bins. A new `testParamHistFunc` unit test was also introduced to validate the results of a ParamHistFunc both with and without the batch mode, comparing to manually computed reference results.
1 parent 150ada5 commit 0de46d4

File tree

4 files changed

+105
-41
lines changed

4 files changed

+105
-41
lines changed

roofit/histfactory/inc/RooStats/HistFactory/ParamHistFunc.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class ParamHistFunc : public RooAbsReal {
2727
ParamHistFunc() ;
2828
ParamHistFunc(const char *name, const char *title, const RooArgList& vars, const RooArgList& paramSet );
2929
ParamHistFunc(const char *name, const char *title, const RooArgList& vars, const RooArgList& paramSet, const TH1* hist );
30-
~ParamHistFunc() override ;
3130

3231
ParamHistFunc(const ParamHistFunc& other, const char* name = 0);
3332
TObject* clone(const char* newname) const override { return new ParamHistFunc(*this, newname); }
@@ -67,8 +66,6 @@ class ParamHistFunc : public RooAbsReal {
6766

6867
class CacheElem : public RooAbsCacheElement {
6968
public:
70-
CacheElem() {} ;
71-
~CacheElem() override {} ;
7269
RooArgList containedArgs(Action) override {
7370
RooArgList ret(_funcIntList) ;
7471
ret.add(_lowIntList);
@@ -85,7 +82,7 @@ class ParamHistFunc : public RooAbsReal {
8582
RooListProxy _dataVars; ///< The RooRealVars
8683
RooListProxy _paramSet ; ///< interpolation parameters
8784

88-
Int_t _numBins;
85+
Int_t _numBins = 0;
8986
struct NumBins {
9087
NumBins() {}
9188
NumBins(int nx, int ny, int nz) : x{nx}, y{ny}, z{nz}, xy{x*y}, xz{x*z}, yz{y*z}, xyz{xy*z} {}

roofit/histfactory/src/ParamHistFunc.cxx

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ ClassImp(ParamHistFunc);
6060
////////////////////////////////////////////////////////////////////////////////
6161

6262
ParamHistFunc::ParamHistFunc()
63-
: _normIntMgr(this), _numBins(0)
63+
: _normIntMgr(this)
6464
{
6565
_dataSet.removeSelfFromDir(); // files must not delete _dataSet.
6666
}
@@ -84,7 +84,6 @@ ParamHistFunc::ParamHistFunc(const char* name, const char* title,
8484
_normIntMgr(this),
8585
_dataVars("!dataVars","data Vars", this),
8686
_paramSet("!paramSet","bin parameters", this),
87-
_numBins(0),
8887
_dataSet( (std::string(name)+"_dataSet").c_str(), "", vars)
8988
{
9089

@@ -130,7 +129,6 @@ ParamHistFunc::ParamHistFunc(const char* name, const char* title,
130129
// _dataVar("!dataVar","data Var", this, (RooRealVar&) var),
131130
_dataVars("!dataVars","data Vars", this),
132131
_paramSet("!paramSet","bin parameters", this),
133-
_numBins(0),
134132
_dataSet( (std::string(name)+"_dataSet").c_str(), "", vars, Hist)
135133
{
136134

@@ -187,14 +185,6 @@ ParamHistFunc::ParamHistFunc(const ParamHistFunc& other, const char* name) :
187185
}
188186

189187

190-
////////////////////////////////////////////////////////////////////////////////
191-
192-
ParamHistFunc::~ParamHistFunc()
193-
{
194-
;
195-
}
196-
197-
198188
////////////////////////////////////////////////////////////////////////////////
199189
/// Get the parameter associated with the index.
200190
/// The index follows RooDataHist indexing conventions.
@@ -603,35 +593,34 @@ double ParamHistFunc::evaluate() const
603593
/// \param[in,out] evalData Input/output data for evaluating the ParamHistFunc.
604594
/// \param[in] normSet Normalisation set passed on to objects that are serving values to us.
605595
void ParamHistFunc::computeBatch(cudaStream_t*, double* output, size_t size, RooFit::Detail::DataMap const& dataMap) const {
606-
std::vector<double> oldValues;
607-
std::vector<RooSpan<const double>> data;
608-
oldValues.reserve(_dataVars.size());
609-
data.reserve(_dataVars.size());
610-
611-
// Retrieve data for all variables
612-
for (auto arg : _dataVars) {
613-
const auto* var = static_cast<RooRealVar*>(arg);
614-
oldValues.push_back(var->getVal());
615-
data.push_back(dataMap.at(var));
616-
}
617596

618-
// Run computation for each entry in the dataset
619-
for (std::size_t i = 0; i < size; ++i) {
620-
for (unsigned int j = 0; j < _dataVars.size(); ++j) {
621-
assert(i < data[j].size());
622-
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
623-
var.setCachedValue(data[j][i], /*notifyClients=*/false);
624-
}
597+
auto const& n = _numBinsPerDim;
598+
// check if _numBins needs to be filled
599+
if(n.x == 0) {
600+
_numBinsPerDim = getNumBinsPerDim(_dataVars);
601+
}
625602

626-
const auto index = _dataSet.getIndex(_dataVars, /*fast=*/true);
627-
const RooAbsReal& param = getParameter(index);
628-
output[i] = param.getVal();
603+
// Different from the evaluate() funnction that first retrieves the indices
604+
// corresponding to the RooDataHist and then transforms them, we can use the
605+
// right bin multiplicators to begin with.
606+
std::array<int, 3> idxMult{{1, n.x, n.xy}};
607+
608+
// As a working buffer for the bin indices, we use the tail of the output
609+
// buffer. We can't use the same starting pointer, otherwise we would
610+
// overwrite the later bin indices as we fill the output.
611+
auto indexBuffer = reinterpret_cast<int*>(output + size) - size;
612+
std::fill(indexBuffer, indexBuffer + size, 0); // output buffer for bin indices needs to be zero-initialized
613+
614+
// Use the vectorized RooAbsBinning::binNumbers() to update the total bin
615+
// index for each dimension, using the `coef` parameter to multiply with the
616+
// right index multiplication factor for each dimension.
617+
for (std::size_t iVar = 0; iVar < _dataVars.size(); ++iVar) {
618+
_dataSet.getBinnings()[iVar]->binNumbers(dataMap.at(&_dataVars[iVar]).data(), indexBuffer, size, idxMult[iVar]);
629619
}
630620

631-
// Restore old values
632-
for (unsigned int j = 0; j < _dataVars.size(); ++j) {
633-
auto& var = static_cast<RooRealVar&>(_dataVars[j]);
634-
var.setCachedValue(oldValues[j], /*notifyClients=*/false);
621+
// Finally, look up the parameters and get their values to fill the output buffer
622+
for (std::size_t i = 0; i < size; ++i) {
623+
output[i] = static_cast<RooAbsReal const&>(_paramSet[indexBuffer[i]]).getVal();
635624
}
636625
}
637626

roofit/histfactory/test/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@
77
# @author Stephan Hageboeck CERN, 2019
88

99
ROOT_ADD_GTEST(testHistFactory testHistFactory.cxx
10-
LIBRARIES RooFitCommon RooFitCore RooFit RooStats HistFactory RooBatchCompute
10+
LIBRARIES RooFitCore RooFit RooStats HistFactory
1111
COPY_TO_BUILDDIR ${CMAKE_CURRENT_SOURCE_DIR}/ref_6.16_example_UsingC_channel1_meas_model.root ${CMAKE_CURRENT_SOURCE_DIR}/ref_6.16_example_UsingC_combined_meas_model.root)
12+
13+
ROOT_ADD_GTEST(testParamHistFunc testParamHistFunc.cxx LIBRARIES RooFitCore HistFactory)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Tests for the ParamHistFunc
2+
// Authors: Jonas Rembser, CERN 08/2022
3+
4+
#include <RooArgSet.h>
5+
#include <RooArgSet.h>
6+
#include <RooConstVar.h>
7+
#include <RooDataSet.h>
8+
#include <RooRandom.h>
9+
#include <RooRealVar.h>
10+
#include <RooStats/HistFactory/ParamHistFunc.h>
11+
12+
#include <gtest/gtest.h>
13+
14+
/// Validate the ParamHistFunc in the n-dimensional case, comparing both the
15+
/// BatchMode and the old implementation results to a manually-compute
16+
/// reference result.
17+
TEST(ParamHistFunc, ValidateND)
18+
{
19+
20+
// Define the number of bins in each dimension
21+
std::array<std::size_t, 3> nbins{{11, 32, 8}};
22+
std::size_t nbinstot = nbins[0] * nbins[1] * nbins[2];
23+
24+
// The bin mltiplication factors to look up the right parameters
25+
std::array<std::size_t, 3> binMult{{1, nbins[0], nbins[0] * nbins[1]}};
26+
27+
// Create the variables and set their bin numbers. The range is tweaked
28+
// such that each integer values falls in a different bin, starting from
29+
// zero.
30+
RooRealVar x{"x", "x", -0.5, nbins[0] - 0.5};
31+
RooRealVar y{"y", "y", -0.5, nbins[1] - 0.5};
32+
RooRealVar z{"z", "z", -0.5, nbins[2] - 0.5};
33+
RooArgSet vars{x, y, z};
34+
for (std::size_t i = 0; i < nbins.size(); ++i) {
35+
static_cast<RooRealVar &>(*vars[i]).setBins(nbins[i]);
36+
}
37+
38+
// Simple set of parameters that just return their index in the parameter
39+
// list
40+
RooArgList params;
41+
for (std::size_t i = 0; i < nbinstot; ++i) {
42+
params.add(RooFit::RooConst(i));
43+
}
44+
45+
ParamHistFunc paramHistFunc{"phf", "phf", vars, params};
46+
47+
std::size_t nEntries = 100;
48+
49+
RooDataSet data{"data", "data", vars};
50+
std::vector<double> resultsRef(nEntries);
51+
std::vector<double> resultsScalar(nEntries);
52+
53+
// Do some things in one go:
54+
// * assing random integer values to each variable in each iteration
55+
// * fill the dataset used for batched evaluation
56+
// * compute the reference result manually
57+
// * compute the result with the ParamHistFunc without BatchMode
58+
for (std::size_t i = 0; i < nEntries; ++i) {
59+
for (std::size_t iVar = 0; iVar < vars.size(); ++iVar) {
60+
auto var = static_cast<RooRealVar *>(vars[iVar]);
61+
var->setVal(int(RooRandom::uniform() * nbins[iVar]));
62+
}
63+
data.add(vars);
64+
resultsRef[i] = binMult[0] * x.getVal() + binMult[1] * y.getVal() + binMult[2] * z.getVal();
65+
resultsScalar[i] = paramHistFunc.getVal();
66+
}
67+
68+
// Get the results in BatchMode using the dataset
69+
auto resultsBatch = paramHistFunc.getValues(data);
70+
71+
// Validate the results
72+
for (std::size_t i = 0; i < nEntries; ++i) {
73+
EXPECT_EQ(int(resultsScalar[i]), int(resultsRef[i])) << "Scalar result is not correct!";
74+
EXPECT_EQ(int(resultsBatch[i]), int(resultsRef[i])) << "BatchMode result is not correct!";
75+
}
76+
}

0 commit comments

Comments
 (0)