Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tmva/tmva/inc/TMVA/RVariablePlotter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// @(#)root/tmva $Id$
// Authors: Simone Azeglio, Lorenzo Moneta , Stefan Wunsch

/*************************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
* Package: TMVA *
* Class : RVariablePlotter *
* Web : http://tmva.sourceforge.net *
* *
* Description: *
* Variable Plotter *
* *
* Authors (alphabetical): *
* Simone Azeglio, University of Turin (Master Student), CERN (Summer Student) *
* Lorenzo Moneta, CERN *
* Stefan Wunsch *
* *
* Copyright (c) 2021: *
* *
* Redistribution and use in source and binary forms, with or without *
* modification, are permitted according to the terms listed in LICENSE *
* (http://tmva.sourceforge.net/LICENSE) *
**********************************************************************************/

#ifndef ROOT_TMVA_RVariablePlotter
#define ROOT_TMVA_RVariablePlotter

//////////////////////////////////////////////////////////////////////////
// //
// RVariablePlotter //
// //
// Variable Plotter //
// //
//////////////////////////////////////////////////////////////////////////

#include <vector>
#include <string>

#include "TLegend.h"
#include "TH1D.h"
#include "THStack.h"

#include "ROOT/RDataFrame.hxx"
#include "ROOT/RDF/RInterface.hxx"

#include "TMVA/RTensor.hxx"


namespace TMVA {

class RVariablePlotter {

public:

// constructor - RDataFrame input
RVariablePlotter(const std::vector<ROOT::RDF::RNode>& nodes, const std::vector<std::string>& labels);

// constructor - Tensor input
RVariablePlotter( const std::vector<TMVA::Experimental::RTensor<float>>& tensors, const std::vector<std::string>& labels);

// draw variables plot - RDataFrame input
void Draw(const std::string& variable);

// convert vector of RTensors to vector of RDataframes
std::vector<ROOT::RDF::RNode> TensorsToNodes();

// draw variables plot - RTensor input
//void DrawTensor(const std::string& variable);

// draw legend
void DrawLegend(float minX, float minY, float maxX, float maxY);


private:

std::vector<ROOT::RDF::RNode> fNodes; //! transient
std::vector<TMVA::Experimental::RTensor<float>> fTensors; //! transient
std::vector<std::string> fLabels;

// convert RTensor to RDataframe
ROOT::RDF::RNode TensorToNode(const TMVA::Experimental::RTensor<float>& tensor);




// flag if "boundary vector" is owned by the volume of not
};

} // namespace TMVA





#endif /* ROOT_TMVA_RVariablePlotter */
186 changes: 186 additions & 0 deletions tmva/tmva/src/RVariablePlotter.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// @(#)root/tmva $Id$
// Authors: Simone Azeglio, Lorenzo Moneta , Stefan Wunsch

/*************************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
* Package: TMVA *
* Class : RVariablePlotter *
* Web : http://tmva.sourceforge.net *
* *
* Description: *
* Variable Plotter *
* *
* Authors (alphabetical): *
* Simone Azeglio, University of Turin (Master Student), CERN (Summer Student) *
* Lorenzo Moneta, CERN *
* Stefan Wunsch *
* *
* Copyright (c) 2021: *
* *
* Redistribution and use in source and binary forms, with or without *
* modification, are permitted according to the terms listed in LICENSE *
* (http://tmva.sourceforge.net/LICENSE) *
**********************************************************************************/


/*! \class TMVA::RVariablePlotter
\ingroup TMVA
Plotting a single variable
*/

#include "TMVA/RVariablePlotter.h"
#include "TMVA/tmvaglob.h"
#include "TMVA/RTensor.hxx"

using namespace TMVA::Experimental;


////////////////////////////////////////////////////////////////////////////////
/// constructor for RDataframe with nodes (samples) and labels

TMVA::RVariablePlotter::RVariablePlotter( const std::vector<ROOT::RDF::RNode>& nodes, const std::vector<std::string>& labels)
: fNodes(nodes),
fLabels(labels){

if (fNodes.size() != fLabels.size())
std::runtime_error("Number of given RDataFrame nodes does not match number of given class labels.");

if (fNodes.size() == 0)
std::runtime_error("Number of given RDataFrame nodes and number of given class labels cannot be zero.");
}


////////////////////////////////////////////////////////////////////////////////
/// constructor for RTensor nodes (samples) and labels

TMVA::RVariablePlotter::RVariablePlotter( const std::vector<TMVA::Experimental::RTensor<float>>& tensors, const std::vector<std::string>& labels)
: fTensors(tensors),
fLabels(labels)
{

if (fTensors.size() != fLabels.size())
std::runtime_error("Number of given RTensor components does not match number of given class labels.");

if (fTensors.size() == 0)
std::runtime_error("Number of given RTensor components and number of given class labels cannot be zero.");
}



////////////////////////////////////////////////////////////////////////////////
/// Drawing variables' plot RDataframe

void TMVA::RVariablePlotter::Draw(const std::string& variable) {
// Make histograms with TH1D

TMVA::TMVAGlob::Initialize( TMVAGlob::SetTMVAStyle );

const auto size = fNodes.size();
std::vector<ROOT::RDF::RResultPtr<TH1D>> histos;

for (std::size_t i = 0; i < size; i++) {
// Trigger event loop with computing the histogram
auto h = fNodes[i].Histo1D(variable);
histos.push_back(h);
}

// Modify style and draw histograms
THStack stack;

for (unsigned int i = 0; i < histos.size(); i++) {
histos[i]->SetLineColor(i + 1);
/*if (i == 0) {
histos[i]->SetTitle("");
histos[i]->SetStats(false);
}
*/
histos[i]->SetTitle(variable.c_str());
histos[i]->SetStats(true);
stack.Add((TH1*) histos[i].GetPtr()->Clone());
}

auto clone = (THStack*) stack.DrawClone("nostack");

//clone->SetTitle(variable.c_str());
//clone->GetXaxis()->SetTitle(variable.c_str());
//clone->GetYaxis()->SetTitle("Count");


}


////////////////////////////////////////////////////////////////////////////////
/// RTensor to RNode Converter

// add variables argument for custom variables
ROOT::RDF::RNode TMVA::RVariablePlotter::TensorToNode(const TMVA::Experimental::RTensor<float>& tensor){

// shape check
if (tensor.GetShape().size() != 2)
std::runtime_error("Number of given RTensor dimensions does not match number of given class labels.");

// Get Tensor Data
const auto dataSig = tensor.GetData();
std::vector<float> vecSig;

for (int i = 0; i< tensor.GetShape()[0]; i++){
vecSig.push_back(dataSig[i]);}

std::string nameVar = "var" + std::to_string(1);
std::vector<std::string> names;
names.push_back(nameVar);

auto dfSig = ROOT::RDataFrame(tensor.GetShape()[0]).Define(names[0], "vecSig" );

for (int j = 1; j < tensor.GetShape()[1]; j++){
vecSig = std::vector<float>();

for (int i = j*tensor.GetShape()[0]; i < (j+1)*tensor.GetShape()[0]; i++){
vecSig.push_back(dataSig[i]);

}

nameVar = "var" + std::to_string(j+1);
names.push_back(nameVar);

dfSig = dfSig.Define(names[j], "vecSig");
}

auto DFNode = ROOT::RDF::RNode(dfSig);

return DFNode;
}

////////////////////////////////////////////////////////////////////////////////
/// RTensor vector to RNode vector converter
std::vector<ROOT::RDF::RNode> TMVA::RVariablePlotter::TensorsToNodes(){

const auto size = fTensors.size();
std::vector<ROOT::RDF::RNode> NodeVec;

// loop through tensors
for (int k = 0; k < fTensors.size(); k++){

NodeVec.push_back(TensorToNode(fTensors[k]));
}

return NodeVec;

}


////////////////////////////////////////////////////////////////////////////////
/// Drawing Legend

void TMVA::RVariablePlotter::DrawLegend(float minX = 0.8, float minY = 0.8, float maxX = 0.9, float maxY = 0.9) {
// make Legend from TLegend
TLegend l(minX, minY, maxX, maxY);
std::vector<TH1D> histos(fLabels.size());

for (unsigned int i = 0; i < fLabels.size(); i++) {
histos[i].SetLineColor(i + 1);
l.AddEntry(&histos[i], fLabels[i].c_str(), "l");
}
l.SetBorderSize(1);
l.DrawClone();
}