forked from root-project/root
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTMVA_SOFIE_Keras.C
More file actions
78 lines (64 loc) · 2.14 KB
/
TMVA_SOFIE_Keras.C
File metadata and controls
78 lines (64 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/// \file
/// \ingroup tutorial_ml
/// \notebook -nodraw
/// This macro provides a simple example of parsing a Keras .h5 file
/// into an RModel object and then generating the .hxx header file for inference.
///
/// \macro_code
/// \macro_output
/// \author Sanjiban Sengupta
using namespace TMVA::Experimental;

// Python script that builds, trains and saves a small Keras model.
// It is written to disk and executed by TMVA_SOFIE_Keras() when no model file is given.
// Adjacent string literals are concatenated by the compiler into a single string.
TString pythonSrc =
   // TF_CPP_MIN_LOG_LEVEL is set before TensorFlow is imported to reduce its log output
   "import os\n"
   "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n"
   "\n"
   "import numpy as np\n"
   "from tensorflow.keras.models import Model\n"
   "from tensorflow.keras.layers import Input,Dense,Activation,ReLU\n"
   "from tensorflow.keras.optimizers import SGD\n"
   "\n"
   // Functional model: 64 inputs (batch size 4) -> Dense 32 -> 16 -> 8 -> 4, with ReLU activations
   "input=Input(shape=(64,),batch_size=4)\n"
   "x=Dense(32)(input)\n"
   "x=Activation('relu')(x)\n"
   "x=Dense(16,activation='relu')(x)\n"
   "x=Dense(8,activation='relu')(x)\n"
   "x=Dense(4)(x)\n"
   "output=ReLU()(x)\n"
   "model=Model(inputs=input,outputs=output)\n"
   "\n"
   // Seeded random training data so the produced weights are reproducible
   "randomGenerator=np.random.RandomState(0)\n"
   "x_train=randomGenerator.rand(4,64)\n"
   "y_train=randomGenerator.rand(4,4)\n"
   "\n"
   // Short SGD training run, then save in the HDF5 format that SOFIE's Keras parser reads
   "model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01))\n"
   "model.fit(x_train, y_train, epochs=5, batch_size=4)\n"
   "model.save('KerasModel.h5')\n";
/// Parse a Keras .h5 model with SOFIE and generate the C++ inference code for it.
///
/// \param modelFile      Path to an existing Keras .h5 file. When nullptr, the Python
///                       script held in `pythonSrc` is written to "make_keras_model.py"
///                       and run with python3; it saves "KerasModel.h5", which is then used.
/// \param printModelInfo When true, print the required input tensors, the initialized
///                       tensors (weights), the intermediate tensors and the generated code.
///
/// The generated header is written by RModel::OutputGenerated().
/// If the model file cannot be produced or found, an error is printed and the
/// function returns without parsing.
void TMVA_SOFIE_Keras(const char * modelFile = nullptr, bool printModelInfo = true){
   // Running the Python script to generate the Keras .h5 file
   if (modelFile == nullptr) {
      TMacro m;
      m.AddLine(pythonSrc);
      m.SaveSource("make_keras_model.py");
      // Exec returns the status of the shell command. A non-zero value means the
      // script failed (e.g. python3 or TensorFlow missing), so no model was written.
      const int status = gSystem->Exec("python3 make_keras_model.py");
      if (status != 0) {
         std::cerr << "TMVA_SOFIE_Keras: running make_keras_model.py failed (status "
                   << status << "), cannot create the Keras model\n";
         return;
      }
      modelFile = "KerasModel.h5";
   }

   // Note ROOT's convention: AccessPathName returns true when the file can NOT be accessed.
   if (gSystem->AccessPathName(modelFile)) {
      std::cerr << "TMVA_SOFIE_Keras: model file '" << modelFile << "' not found\n";
      return;
   }

   // Parsing the saved Keras .h5 file into an RModel object
   SOFIE::RModel model = SOFIE::PyKeras::Parse(modelFile);
   // Generating inference code
   model.Generate();
   // Generate the output header. By default it will be modelName.hxx
   model.OutputGenerated();

   if (!printModelInfo) return;

   // Printing required input tensors
   std::cout << "\n\n";
   model.PrintRequiredInputTensors();
   // Printing initialized tensors (weights)
   std::cout << "\n\n";
   model.PrintInitializedTensors();
   // Printing intermediate tensors
   std::cout << "\n\n";
   model.PrintIntermediateTensors();
   // Printing generated inference code
   std::cout << "\n\n";
   model.PrintGenerated();
}