Skip to content

Commit 59fd728

Browse files
committed
[RF] Also read other types than int into RooAbsCategory from TTree
For `RooAbsReal`, this was already implemented in [1] with the TreeReadBuffer mechanism. This commit uses the same approach for the RooAbsCategory as well, such that one can read TTree branches of all fundamental types to RooFit categories. Fixes root-project#10278. [1] a18675a, "Add capability to read ULong64_t + more into RooDataSets."
1 parent 5a995a0 commit 59fd728

File tree

6 files changed

+140
-83
lines changed

6 files changed

+140
-83
lines changed

roofit/roofitcore/inc/RooAbsCategory.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@ class TTree;
3333
class RooVectorDataStore;
3434
class Roo1DTable;
3535
class TIterator;
36+
struct TreeReadBuffer; /// A space to attach TBranches
3637

3738
class RooAbsCategory : public RooAbsArg {
3839
public:
3940
/// The type used to denote a specific category state.
4041
using value_type = int;
4142

4243
// Constructors, assignment etc.
43-
RooAbsCategory() { };
44+
RooAbsCategory();
4445
RooAbsCategory(const char *name, const char *title);
4546
RooAbsCategory(const RooAbsCategory& other, const char* name=0) ;
4647
~RooAbsCategory() override;
@@ -215,13 +216,14 @@ class RooAbsCategory : public RooAbsArg {
215216
mutable value_type _currentIndex{std::numeric_limits<int>::min()}; ///< Current category state
216217
std::map<std::string, value_type> _stateNames; ///< Map state names to index numbers. Make sure state names are updated in recomputeShape().
217218
std::vector<std::string> _insertionOrder; ///< Keeps track in which order state numbers have been inserted. Make sure this is updated in recomputeShape().
218-
mutable UChar_t _byteValue{0}; ///<! Transient cache for byte values from tree branches
219219
mutable std::map<value_type, std::unique_ptr<RooCatType, std::function<void(RooCatType*)>> > _legacyStates; ///<! Map holding pointers to RooCatType instances. Only for legacy interface. Don't use if possible.
220-
bool _treeVar{false}; ///< Is this category attached to a tree?
221220

222221
static const decltype(_stateNames)::value_type& invalidCategory();
223222

224-
ClassDefOverride(RooAbsCategory, 3) // Abstract discrete variable
223+
private:
224+
std::unique_ptr<TreeReadBuffer> _treeReadBuffer; //! A buffer for reading values from trees
225+
226+
ClassDefOverride(RooAbsCategory, 4) // Abstract discrete variable
225227
};
226228

227229
#endif

roofit/roofitcore/src/RooAbsCategory.cxx

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ the following replacements should be used:
5252
#include "RooMsgService.h"
5353
#include "RooVectorDataStore.h"
5454
#include "RooFitLegacy/RooAbsCategoryLegacyIterator.h"
55+
#include "TreeReadBuffer.h"
5556

5657
#include "Compression.h"
5758
#include "TString.h"
5859
#include "TTree.h"
5960
#include "TLeaf.h"
6061
#include "TBranch.h"
6162

63+
#include <functional>
6264
#include <memory>
6365

6466
using namespace std;
@@ -72,6 +74,10 @@ const decltype(RooAbsCategory::_stateNames)::value_type& RooAbsCategory::invalid
7274
return invalid;
7375
}
7476

77+
78+
RooAbsCategory::RooAbsCategory() {}
79+
80+
7581
////////////////////////////////////////////////////////////////////////////////
7682
/// Constructor
7783

@@ -90,8 +96,7 @@ RooAbsCategory::RooAbsCategory(const char *name, const char *title) :
9096
RooAbsCategory::RooAbsCategory(const RooAbsCategory& other,const char* name) :
9197
RooAbsArg(other,name), _currentIndex(other._currentIndex),
9298
_stateNames(other._stateNames),
93-
_insertionOrder(other._insertionOrder),
94-
_treeVar(other._treeVar)
99+
_insertionOrder(other._insertionOrder)
95100
{
96101
setValueDirty() ;
97102
setShapeDirty() ;
@@ -439,41 +444,65 @@ void RooAbsCategory::attachToVStore(RooVectorDataStore& vstore)
439444
/// Attach the category index and label as branches to the given
440445
/// TTree. The index field will be attached as integer with name
441446
/// `<name>_idx`. If a branch `<name>` exists, it attaches to this branch.
442-
void RooAbsCategory::attachToTree(TTree& t, Int_t bufSize)
447+
void RooAbsCategory::attachToTree(TTree& tree, Int_t bufSize)
443448
{
444449
// First check if there is an integer branch matching the category name
445450
TString cleanName(cleanBranchName()) ;
446-
TBranch* branch = t.GetBranch(cleanName) ;
451+
TBranch* branch = tree.GetBranch(cleanName) ;
447452
if (!branch) {
448453
cleanName += "_idx";
449-
branch = t.GetBranch(cleanName);
454+
branch = tree.GetBranch(cleanName);
450455
}
451456

452457
if (branch) {
453-
TString typeName(((TLeaf*)branch->GetListOfLeaves()->At(0))->GetTypeName()) ;
454-
if (!typeName.CompareTo("Int_t")) {
455-
// Imported TTree: attach only index field as branch
456-
457-
coutI(DataHandling) << "RooAbsCategory::attachToTree(" << GetName() << ") TTree branch " << GetName()
458-
<< " will be interpreted as category index" << endl ;
459-
460-
t.SetBranchAddress(cleanName, &_currentIndex) ;
461-
setAttribute("INTIDXONLY_TREE_BRANCH",kTRUE) ;
462-
_treeVar = true;
463-
return ;
464-
} else if (!typeName.CompareTo("UChar_t")) {
465-
coutI(DataHandling) << "RooAbsReal::attachToTree(" << GetName() << ") TTree UChar_t branch " << GetName()
466-
<< " will be interpreted as category index" << endl ;
467-
t.SetBranchAddress(cleanName,&_byteValue) ;
468-
setAttribute("UCHARIDXONLY_TREE_BRANCH",kTRUE) ;
469-
_treeVar = true;
458+
TLeaf* leaf = (TLeaf*)branch->GetListOfLeaves()->At(0) ;
459+
460+
// Check that leaf is _not_ an array
461+
Int_t dummy ;
462+
TLeaf* counterLeaf = leaf->GetLeafCounter(dummy) ;
463+
if (counterLeaf) {
464+
coutE(Eval) << "RooAbsCategory::attachToTree(" << GetName() << ") ERROR: TTree branch " << GetName()
465+
<< " is an array and cannot be attached to a RooAbsCategory" << endl ;
470466
return ;
471467
}
468+
469+
TString typeName(leaf->GetTypeName()) ;
470+
471+
472+
// For different type names, store a function to attach
473+
std::map<std::string, std::function<std::unique_ptr<TreeReadBuffer>()>> typeMap {
474+
{"Float_t", [&](){ return createTreeReadBuffer<Float_t >(cleanName, tree); }},
475+
{"Double_t", [&](){ return createTreeReadBuffer<Double_t >(cleanName, tree); }},
476+
{"UChar_t", [&](){ return createTreeReadBuffer<UChar_t >(cleanName, tree); }},
477+
{"Bool_t", [&](){ return createTreeReadBuffer<Bool_t >(cleanName, tree); }},
478+
{"Char_t", [&](){ return createTreeReadBuffer<Char_t >(cleanName, tree); }},
479+
{"UInt_t", [&](){ return createTreeReadBuffer<UInt_t >(cleanName, tree); }},
480+
{"Long64_t", [&](){ return createTreeReadBuffer<Long64_t >(cleanName, tree); }},
481+
{"ULong64_t", [&](){ return createTreeReadBuffer<ULong64_t>(cleanName, tree); }},
482+
{"Short_t", [&](){ return createTreeReadBuffer<Short_t >(cleanName, tree); }},
483+
{"UShort_t", [&](){ return createTreeReadBuffer<UShort_t >(cleanName, tree); }},
484+
};
485+
486+
auto typeDetails = typeMap.find(typeName.Data());
487+
if (typeDetails != typeMap.end()) {
488+
coutI(DataHandling) << "RooAbsCategory::attachToTree(" << GetName() << ") TTree " << typeName << " branch \"" << cleanName
489+
<< "\" will be converted to int." << endl ;
490+
_treeReadBuffer = typeDetails->second();
491+
} else {
492+
_treeReadBuffer = nullptr;
493+
494+
if (!typeName.CompareTo("Int_t")) {
495+
tree.SetBranchAddress(cleanName, &_currentIndex);
496+
}
497+
else {
498+
coutE(InputArguments) << "RooAbsCategory::attachToTree(" << GetName() << ") data type " << typeName << " is not supported." << endl ;
499+
}
500+
}
472501
} else {
473502
TString format(cleanName);
474503
format.Append("/I");
475504
void* ptr = &_currentIndex;
476-
t.Branch(cleanName, ptr, (const Text_t*)format, bufSize);
505+
tree.Branch(cleanName, ptr, (const Text_t*)format, bufSize);
477506
}
478507
}
479508

@@ -533,33 +562,17 @@ void RooAbsCategory::copyCache(const RooAbsArg *source, Bool_t /*valueOnly*/, Bo
533562
auto other = static_cast<const RooAbsCategory*>(source);
534563
assert(dynamic_cast<const RooAbsCategory*>(source));
535564

536-
_currentIndex = other->_currentIndex;
537-
538-
if (setValDirty) {
539-
setValueDirty();
540-
}
541-
542-
if (!_treeVar)
543-
return;
544-
545-
if (source->getAttribute("INTIDXONLY_TREE_BRANCH")) {
546-
// Lookup cat state from other-index because label is missing
547-
if (hasIndex(other->_currentIndex)) {
548-
_currentIndex = other->_currentIndex;
549-
} else {
550-
coutE(DataHandling) << "RooAbsCategory::copyCache(" << GetName() << ") ERROR: index of source arg "
551-
<< source->GetName() << " is invalid (" << other->_currentIndex
552-
<< "), value not updated" << endl;
553-
}
554-
} else if (source->getAttribute("UCHARIDXONLY_TREE_BRANCH")) {
555-
// Lookup cat state from other-index because label is missing
556-
Int_t tmp = static_cast<int>(other->_byteValue);
557-
if (hasIndex(tmp)) {
558-
_currentIndex = tmp;
559-
} else {
560-
coutE(DataHandling) << "RooAbsCategory::copyCache(" << GetName() << ") ERROR: index of source arg "
561-
<< source->GetName() << " is invalid (" << tmp << "), value not updated" << endl;
565+
value_type tmp = other->_treeReadBuffer ? *other->_treeReadBuffer : other->_currentIndex;
566+
// Lookup cat state from other-index because label is missing
567+
if (hasIndex(tmp)) {
568+
_currentIndex = tmp;
569+
if (setValDirty) {
570+
setValueDirty();
562571
}
572+
} else {
573+
coutE(DataHandling) << "RooAbsCategory::copyCache(" << GetName() << ") ERROR: index of source arg "
574+
<< source->GetName() << " is invalid (" << other->_currentIndex
575+
<< "), value not updated" << endl;
563576
}
564577
}
565578

roofit/roofitcore/src/RooAbsReal.cxx

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
#include "RooCachedReal.h"
8585
#include "RooHelpers.h"
8686
#include "RunContext.h"
87+
#include "TreeReadBuffer.h"
8788
#include "ValueChecking.h"
8889

8990
#include "ROOT/StringUtils.hxx"
@@ -3205,12 +3206,6 @@ RooAbsFunc *RooAbsReal::bindVars(const RooArgSet &vars, const RooArgSet* nset, B
32053206

32063207

32073208

3208-
struct TreeReadBuffer {
3209-
virtual ~TreeReadBuffer() = default;
3210-
virtual operator double() = 0;
3211-
};
3212-
3213-
32143209
////////////////////////////////////////////////////////////////////////////////
32153210
/// Copy the cached value of another RooAbsArg to our cache.
32163211
/// Warning: This function just copies the cached values of source,
@@ -3238,30 +3233,6 @@ void RooAbsReal::attachToVStore(RooVectorDataStore& vstore)
32383233
}
32393234

32403235

3241-
namespace {
3242-
/// Helper for reading branches with various types from a TTree, and convert all to double.
3243-
template<typename T>
3244-
struct TypedTreeReadBuffer final : public TreeReadBuffer {
3245-
operator double() override {
3246-
return _value;
3247-
}
3248-
T _value;
3249-
};
3250-
3251-
/// Create a TreeReadBuffer to hold the specified type, and attach to the branch passed as argument.
3252-
/// \tparam T Type of branch to be read.
3253-
/// \param[in] branchName Attach to this branch.
3254-
/// \param[in] tree Tree to attach to.
3255-
template<typename T>
3256-
std::unique_ptr<TreeReadBuffer> createTreeReadBuffer(const TString& branchName, TTree& tree) {
3257-
auto buf = new TypedTreeReadBuffer<T>();
3258-
tree.SetBranchAddress(branchName.Data(), &buf->_value);
3259-
return std::unique_ptr<TreeReadBuffer>(buf);
3260-
}
3261-
3262-
}
3263-
3264-
32653236
////////////////////////////////////////////////////////////////////////////////
32663237
/// Attach object to a branch of given TTree. By default it will
32673238
/// register the internal value cache RooAbsReal::_value as branch
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Project: RooFit
3+
* Authors:
4+
* Stephan Hageboeck, CERN 2020
5+
*
6+
* Copyright (c) 2022, CERN
7+
*
8+
* Redistribution and use in source and binary forms,
9+
* with or without modification, are permitted according to the terms
10+
* listed in LICENSE (http://roofit.sourceforge.net/license.txt)
11+
*/
12+
13+
#ifndef RooFit_TreeReadBuffer_h
14+
#define RooFit_TreeReadBuffer_h
15+
16+
#include <TTree.h>
17+
18+
struct TreeReadBuffer {
19+
virtual ~TreeReadBuffer() = default;
20+
virtual operator double() = 0;
21+
virtual operator int() = 0;
22+
};
23+
24+
/// Helper for reading branches with various types from a TTree, and convert all to double.
25+
template <typename T>
26+
struct TypedTreeReadBuffer final : public TreeReadBuffer {
27+
operator double() override { return _value; }
28+
operator int() override { return _value; }
29+
T _value;
30+
};
31+
32+
/// Create a TreeReadBuffer to hold the specified type, and attach to the branch passed as argument.
33+
/// \tparam T Type of branch to be read.
34+
/// \param[in] branchName Attach to this branch.
35+
/// \param[in] tree Tree to attach to.
36+
template <typename T>
37+
std::unique_ptr<TreeReadBuffer> createTreeReadBuffer(const TString &branchName, TTree &tree)
38+
{
39+
auto buf = new TypedTreeReadBuffer<T>();
40+
tree.SetBranchAddress(branchName.Data(), &buf->_value);
41+
return std::unique_ptr<TreeReadBuffer>(buf);
42+
}
43+
44+
#endif

roofit/roofitcore/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
ROOT_ADD_GTEST(simple simple.cxx LIBRARIES RooFitCore)
1111
ROOT_ADD_GTEST(testRooCacheManager testRooCacheManager.cxx LIBRARIES RooFitCore)
12+
ROOT_ADD_GTEST(testRooCategory testRooCategory.cxx LIBRARIES RooFitCore)
1213
ROOT_ADD_GTEST(testWorkspace testWorkspace.cxx LIBRARIES RooFitCore RooStats)
1314
if(NOT MSVC OR win_broken_tests)
1415
ROOT_ADD_GTEST(testRooDataHist testRooDataHist.cxx LIBRARIES RooFitCore
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Tests for the RooCategory
2+
// Author: Jonas Rembser, CERN 04/2021
3+
4+
#include <RooCategory.h>
5+
#include <RooDataSet.h>
6+
#include <RooGlobalFunc.h>
7+
8+
#include <TTree.h>
9+
10+
#include <gtest/gtest.h>
11+
12+
// GitHub issue 10278: RooDataSet incorrectly loads RooCategory values from TTree branch of type Short_t
13+
TEST(RooCategory, CategoryDefineMultiState)
14+
{
15+
TTree tree("test_tree", "Test tree");
16+
Short_t cat_in;
17+
tree.Branch("cat", &cat_in);
18+
19+
cat_in = 2; // category B
20+
tree.Fill();
21+
22+
RooCategory cat("cat", "Category", {{"B_cat", 2}, {"A_cat", 3}});
23+
RooDataSet data("data", "RooDataSet", RooArgSet(cat), RooFit::Import(tree));
24+
25+
EXPECT_EQ(static_cast<RooCategory &>((*data.get(0))["cat"]).getCurrentIndex(), 2);
26+
}

0 commit comments

Comments
 (0)