Skip to content

Commit 2c97479

Browse files
[GSOC][TMVA][SOFIE] Cast ONNX Operator implemented with the corresponding unit tests (root-project#11033)
* Cast ONNX Operator implemented with the corresponding unit tests Added the functionality and support of int input type in Cast ONNX Operator * Attribute type added to ROperator Cast Class and modified the RModel Parser for supporting different input types * The functionality and support for other datatypes added for the cast operator and added the support to RModel::Generate method also * made the required changes related to support for other datatypes * Extended the support for other datatypes in the infer function * Required changes made related to support of different datatypes in SOFIE * Apply various fixes to support different input/output type. This fixes the new Cast operator. Several changes are needed for Cast since the input tensor can be of a type different than float Apply also a fix for parsing correctly the attribute of Cast * The attribute fattr_type changed to fAttrType Co-authored-by: moneta <[email protected]>
1 parent 2e28217 commit 2c97479

File tree

10 files changed

+241
-31
lines changed

10 files changed

+241
-31
lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
3333
TMVA/ROperator_Concat.hxx
3434
TMVA/ROperator_Identity.hxx
3535
TMVA/ROperator_Softmax.hxx
36+
TMVA/ROperator_Cast.hxx
3637
TMVA/SOFIE_common.hxx
3738
TMVA/SOFIEHelpers.hxx
3839
SOURCES

tmva/sofie/inc/TMVA/OperatorList.hxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
#include "TMVA/ROperator_Identity.hxx"
1919
#include "TMVA/ROperator_Softmax.hxx"
2020
#include "TMVA/ROperator_Concat.hxx"
21+
#include "TMVA/ROperator_Cast.hxx"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#ifndef TMVA_SOFIE_ROPERATOR_Cast
#define TMVA_SOFIE_ROPERATOR_Cast

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

/// SOFIE implementation of the ONNX Cast operator.
/// Copies the input tensor element-wise into an output tensor whose
/// element type is given by the Cast "to" attribute (stored here as a
/// C++ type name string, e.g. "float", "int64_t").
/// \tparam T element type of the *input* tensor
template <typename T>
class ROperator_Cast final : public ROperator
{

private:

   std::string fNX;                  ///< input tensor name (cleaned)
   std::string fNY;                  ///< output tensor name (cleaned)
   std::vector<size_t> fShape;       ///< shape of input (== output) tensor
   std::string fAttrType = "float";  ///< target type from the ONNX "to" attribute

public:
   ROperator_Cast(){}
   /// \param attr_type target element type name (from the ONNX "to" attribute)
   /// \param nameX input tensor name
   /// \param nameY output tensor name
   ROperator_Cast(std::string attr_type, std::string nameX, std::string nameY):
      fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)),
      fAttrType(attr_type) {}

   // NOTE(review): Cast changes the element type to fAttrType, but this
   // returns the input types unchanged — confirm whether callers rely on
   // TypeInference here (AddIntermediateTensor below registers the correct type).
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
      return input;
   }

   /// Cast is element-wise: the output shape equals the input shape.
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
      auto ret = input; //suggest copy to compiler
      return ret;
   }

   void Initialize(RModel& model){
      // input must be a graph input, or already initialized intermediate tensor
      if (!model.CheckIfTensorAlreadyExist(fNX)){
         throw std::runtime_error("TMVA SOFIE Cast Op Input Tensor is not found in model");
      }
      fShape = model.GetTensorShape(fNX);
      // register the output tensor with the requested (cast-to) element type
      model.AddIntermediateTensor(fNY, ConvertStringToType(fAttrType), fShape);
   }

   /// Emit the generated C++ code performing the element-wise cast.
   /// \param OpName unique operator id used to prefix generated symbols
   std::string Generate(std::string OpName){
      OpName = "op_" + OpName;
      if (fShape.empty()) {
         throw std::runtime_error("TMVA SOFIE Cast called to Generate without being initialized first");
      }
      std::stringstream out;
      size_t length = ConvertShapeToLength(fShape);

      out << "\n//------ CAST\n";
      // emit a size_t index so the generated loop avoids any
      // signed/unsigned comparison against the (size_t) tensor length
      out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
      out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< fAttrType << ">(tensor_" << fNX << "[id]);\n";
      out << SP << "}\n";
      return out.str();
   }

};

}//SOFIE
}//Experimental
}//TMVA

#endif //TMVA_SOFIE_ROPERATOR_Cast

tmva/sofie/src/RModel.cxx

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,19 @@ namespace SOFIE{
277277
}
278278
}
279279
for (auto&i: fIntermediateTensorInfos){
280+
size_t length = ConvertShapeToLength(i.second.shape);
280281
if (i.second.type == ETensorType::FLOAT){
281-
size_t length = 1;
282-
for (auto & dim: i.second.shape){
283-
length *= dim;
284-
}
285-
//fGC += "float tensor_" + i.first + "[" + std::to_string(length) + "];\n";
286282
fGC += "std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string(length) + ");\n";
287283
fGC += "float * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
288284
}
285+
if (i.second.type == ETensorType::DOUBLE){
286+
fGC += "std::vector<double> fTensor_" + i.first + " = std::vector<double>(" + std::to_string(length) + ");\n";
287+
fGC += "double * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
288+
}
289+
if (i.second.type == ETensorType::INT64){
290+
fGC += "std::vector<int64_t> fTensor_" + i.first + " = std::vector<int64_t>(" + std::to_string(length) + ");\n";
291+
fGC += "int64_t * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
292+
}
289293
}
290294
if (fUseSession) {
291295
// add here specific operator code that needs to define session data members
@@ -310,14 +314,15 @@ namespace SOFIE{
310314
}
311315

312316
size_t outputSize = fOutputTensorNames.size();
317+
// assume output types are all the same
318+
std::string outputType;
313319
if (outputSize == 1) {
314320
auto f = fIntermediateTensorInfos.find(fOutputTensorNames[0]);
315321
if (f == fIntermediateTensorInfos.end()){
316322
throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[0] + " not found when trying to get its info");
317323
}else{
318-
if (f->second.type == ETensorType::FLOAT){
319-
fGC += "std::vector<float> ";
320-
}
324+
outputType = ConvertTypeToString(f->second.type);
325+
fGC += "std::vector<" + outputType + "> ";
321326
}
322327
} else {
323328
std::vector<ETensorType> outputTensorsTypes(outputSize);
@@ -330,45 +335,55 @@ namespace SOFIE{
330335
outputTensorsTypes[i] = f->second.type;
331336
}
332337
}
333-
ETensorType outputType = outputTensorsTypes[0];
338+
// assume all output types are the same
339+
outputType = ConvertTypeToString(outputTensorsTypes[0]);
334340
for (size_t i = 0; i < outputSize; i++) {
335-
if (outputTensorsTypes[i] != outputType) {
341+
if (outputTensorsTypes[i] != outputTensorsTypes[0]) {
336342
throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[i] + " is of different type.");
337343
}
338344
}
339-
if (outputType == ETensorType::FLOAT) {
340-
fGC += "std::vector<std::vector<float>> ";
341-
}
345+
fGC += "std::vector<std::vector<" + outputType + ">> ";
342346
}
343347

344348
fGC += "infer(";
345349
for (auto& i: fReadyInputTensorInfos){
346350
if (i.second.type == ETensorType::FLOAT){
347-
fGC += "float* tensor_" + i.first + ",";
351+
fGC += "float* tensor_" + i.first + ",";
352+
}
353+
else if (i.second.type == ETensorType::INT32 ){
354+
fGC += "int32_t* tensor_" + i.first + ",";
355+
}
356+
else if (i.second.type == ETensorType::INT64){
357+
fGC += "int64_t* tensor_" + i.first + ",";
358+
}
359+
else if(i.second.type == ETensorType::DOUBLE){
360+
fGC += "double* tensor_" + i.first + ",";
348361
}
349362
}
350363
fGC.pop_back(); //remove last ","
351364
fGC += "){\n";
352365

366+
const std::string SP = " ";
367+
353368
for (size_t id = 0; id < fOperators.size() ; id++){
354369
fGC+= (fOperators[id]->Generate(std::to_string(id)));
355370
}
356371
if (outputSize == 1) {
357372
size_t outputLength = ConvertShapeToLength(GetTensorShape(fOutputTensorNames[0]));
358373

359-
fGC += "\tstd::vector<float> ret (tensor_" + fOutputTensorNames[0] + ", tensor_" + fOutputTensorNames[0] + " + " +
374+
fGC += SP + "std::vector<" + outputType + "> ret (tensor_" + fOutputTensorNames[0] + ", tensor_" + fOutputTensorNames[0] + " + " +
360375
std::to_string(outputLength) + ");\n";
361376
} else {
362377
for (size_t i = 0; i < outputSize; i++) {
363378
if (!fOutputTensorNames[i].empty()) {
364379
size_t outputLength = ConvertShapeToLength(GetTensorShape(fOutputTensorNames[i]));
365-
fGC += "\tstd::vector<float> ret_";
380+
fGC += SP + "std::vector<" + outputType + "> ret_";
366381
fGC += std::to_string(i);
367382
fGC += " (tensor_" + fOutputTensorNames[i] + ", tensor_" + fOutputTensorNames[i] + " + " +
368383
std::to_string(outputLength) + ");\n";
369384
}
370385
}
371-
fGC += "\tstd::vector<std::vector<float>> ret({";
386+
fGC += SP + "std::vector<std::vector<" + outputType + ">> ret({";
372387
for (size_t i = 0; i < outputSize; i++) {
373388
if (fOutputTensorNames[i].empty()) {
374389
fGC += "{}";
@@ -382,7 +397,7 @@ namespace SOFIE{
382397
}
383398
fGC += "});\n";
384399
}
385-
fGC += "\treturn ret;\n";
400+
fGC += SP + "return ret;\n";
386401
fGC += "}\n";
387402
if (fUseSession) {
388403
fGC += "};\n";
@@ -394,7 +409,7 @@ namespace SOFIE{
394409
void RModel::ReadInitializedTensorsFromFile() {
395410
// generate the code to read initialized tensors from a text data file
396411
if (fInitializedTensors.empty()) return;
397-
412+
398413
fGC += " std::ifstream f;\n";
399414
fGC += " f.open(filename);\n";
400415
fGC += " if (!f.is_open()){\n";

tmva/sofie/src/SOFIE_common.cxx

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,43 @@ std::string ConvertTypeToString(ETensorType type){
2525
case ETensorType::FLOAT : {
2626
return "float";
2727
}
28+
case ETensorType::INT16 : {
29+
return "int16_t";
30+
}
31+
case ETensorType::INT32 : {
32+
return "int32_t";
33+
}
34+
case ETensorType::INT64 : {
35+
return "int64_t";
36+
}
37+
case ETensorType::UINT16 : {
38+
return "uint16_t";
39+
}
40+
case ETensorType::UINT32 : {
41+
return "uint32_t";
42+
}
43+
case ETensorType::UINT64 : {
44+
return "uint64_t";
45+
}
46+
case ETensorType::DOUBLE : {
47+
return "double";
48+
}
2849
default:{
2950
return "other";
3051
}
3152
}
3253
}
3354

3455
ETensorType ConvertStringToType(std::string type){
35-
if(type == "float32" || type == "Float"){
56+
if(type == "float32" || type == "float" || type == "Float"){
3657
return ETensorType::FLOAT;
3758
}
59+
else if(type == "int64"){
60+
return ETensorType::INT64;
61+
}
62+
else if (type == "double" || type == "float64"){
63+
return ETensorType::DOUBLE;
64+
}
3865
else{
3966
return ETensorType::UNDEFINED;
4067
}

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
#include "Neg_FromONNX.hxx"
3131
#include "input_models/references/Neg.ref.hxx"
3232

33+
#include "Cast_FromONNX.hxx"
34+
#include "input_models/references/Cast.ref.hxx"
35+
3336
#include "LinearWithLeakyRelu_FromONNX.hxx"
3437
#include "input_models/references/LinearWithLeakyRelu.ref.hxx"
3538

@@ -308,7 +311,7 @@ TEST(ONNX, Neg)
308311
-1.9100, 1.8811, -1.7269, -0.1094, -0.0145, 0.2509, 0.5893, -2.2733,
309312
-0.7077, 1.0645, -0.8607, 0.2085
310313
});
311-
314+
312315
TMVA_SOFIE_Neg::Session s("Neg_FromONNX.dat");
313316
std::vector<float> output = s.infer(input.data());
314317

@@ -323,6 +326,30 @@ TEST(ONNX, Neg)
323326
}
324327
}
325328

329+
// Unit test for the SOFIE-generated Cast model: feeds an int64 input
// tensor through the generated Session and compares the (cast) output
// against the reference values in Cast_ExpectedOutput.
TEST(ONNX, Cast)
{
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;

   // Preparing the standard input (int64, matching the ONNX model's input type)
   std::vector<int64_t> input({
      1,2,3,4,5,6
   });

   TMVA_SOFIE_Cast::Session s("Cast_FromONNX.dat");

   // auto: the element type of the returned vector is the model's cast-to type
   auto output = s.infer(input.data());

   // Checking output size
   EXPECT_EQ(output.size(), sizeof(Cast_ExpectedOutput::outputs) / sizeof(float));

   float *correct = Cast_ExpectedOutput::outputs;

   // Checking every output value, one by one
   for (size_t i = 0; i < output.size(); ++i) {
      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
   }
}
352+
326353
TEST(ONNX, Linear64)
327354
{
328355
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
@@ -589,12 +616,12 @@ TEST(ONNX, MaxPool1d){
589616
0.2283, 0.8947, 1.7627,
590617
-0.1657, 0.0649, -1.6066, 0.4162, -1.1525, -0.8184, 1.1324,
591618
-1.1086, 0.1061, 1.0071});
592-
619+
593620
TMVA_SOFIE_MaxPool1d::Session s("MaxPool1d_FromONNX.dat");
594621
std::vector<float> output = s.infer(input.data());
595622
// Checking output size
596623
EXPECT_EQ(output.size(), sizeof(MaxPool1d_ExpectedOutput::output) / sizeof(float));
597-
624+
598625
float *correct = MaxPool1d_ExpectedOutput::output;
599626

600627
// Checking every output value, one by one
@@ -620,12 +647,12 @@ TEST(ONNX, MaxPool2d){
620647
-0.9398, -0.2065, -0.9499, -0.9739, -0.1288, -0.1375, -1.2612,
621648
0.8810, 0.8506, 0.4455
622649
});
623-
650+
624651
TMVA_SOFIE_MaxPool2d::Session s("MaxPool2d_FromONNX.dat");
625652
std::vector<float> output = s.infer(input.data());
626653
// Checking output size
627654
EXPECT_EQ(output.size(), sizeof(MaxPool2d_ExpectedOutput::output) / sizeof(float));
628-
655+
629656
float *correct = MaxPool2d_ExpectedOutput::output;
630657

631658
// Checking every output value, one by one
@@ -652,12 +679,12 @@ TEST(ONNX, MaxPool3d){
652679
-0.5477, 0.2341, 0.9181,
653680
0.3842, 0.2428, 1.7924
654681
});
655-
682+
656683
TMVA_SOFIE_MaxPool3d::Session s("MaxPool3d_FromONNX.dat");
657684
std::vector<float> output = s.infer(input.data());
658685
// Checking output size
659686
EXPECT_EQ(output.size(), sizeof(MaxPool3d_ExpectedOutput::output) / sizeof(float));
660-
687+
661688
float *correct = MaxPool3d_ExpectedOutput::output;
662689

663690
// Checking every output value, one by one
@@ -683,12 +710,12 @@ TEST(ONNX, AvgPool){
683710
-1.4971, 0.5386, -0.2922, 0.4860, -0.3973, -0.4624, 0.4514,
684711
0.2385, 0.3783, -1.0500
685712
});
686-
713+
687714
TMVA_SOFIE_AvgPool::Session s("AvgPool_FromONNX.dat");
688715
std::vector<float> output = s.infer(input.data());
689716
// Checking output size
690717
EXPECT_EQ(output.size(), sizeof(AvgPool_ExpectedOutput::output) / sizeof(float));
691-
718+
692719
float *correct = AvgPool_ExpectedOutput::output;
693720

694721
// Checking every output value, one by one
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
pytorch1.11.0:s
2+
*
3+
onnx::Cast_01Cast_0"Cast*
4+
to �torch-jit-exportZ
5+
onnx::Cast_0
6+

7+

8+
b
9+
1
10+
 
11+

12+
B
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Reference output for the Cast operator unit test:
// the int64 input values 1..6 after casting to float.
namespace Cast_ExpectedOutput{
   float outputs[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
} // namespace Cast_ExpectedOutput

0 commit comments

Comments
 (0)