@@ -13,7 +13,7 @@ TEST(Evaluators, DivIntEvaluatesCorrectly) {
13
13
return (%3))IR" ;
14
14
15
15
auto g = std::make_shared<torch::jit::Graph>();
16
- torch::jit::parseIR (graph, &*g );
16
+ torch::jit::parseIR (graph, g. get () );
17
17
18
18
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
19
19
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
@@ -30,7 +30,7 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
30
30
return (%3))IR" ;
31
31
32
32
auto g = std::make_shared<torch::jit::Graph>();
33
- torch::jit::parseIR (graph, &*g );
33
+ torch::jit::parseIR (graph, g. get () );
34
34
35
35
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
36
36
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
@@ -49,7 +49,7 @@ TEST(Evaluators, ZerosEvaluatesCorrectly) {
49
49
auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
50
50
51
51
auto g = std::make_shared<torch::jit::Graph>();
52
- torch::jit::parseIR (graph, &*g );
52
+ torch::jit::parseIR (graph, g. get () );
53
53
54
54
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {in});
55
55
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {in});
@@ -69,15 +69,118 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
69
69
auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
70
70
71
71
auto g = std::make_shared<torch::jit::Graph>();
72
- torch::jit::parseIR (graph, &*g );
72
+ torch::jit::parseIR (graph, g. get () );
73
73
74
74
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {in});
75
75
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {in});
76
76
77
77
ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
78
78
}
79
79
80
- TEST (Evaluators, SizeConvertsCorrectly) {
80
+ TEST (Evaluators, ATenArangeIntEvaluatesCorrectly) {
81
+ const auto graph = R"IR(
82
+ graph():
83
+ %0 : int = prim::Constant[value=51]()
84
+ %1 : None = prim::Constant()
85
+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
86
+ return (%2))IR" ;
87
+
88
+ auto g = std::make_shared<torch::jit::Graph>();
89
+ torch::jit::parseIR (graph, &*g);
90
+
91
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
92
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
93
+
94
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
95
+ }
96
+
97
+ TEST (Evaluators, ATenArangeFloatEvaluatesCorrectly) {
98
+ const auto graph = R"IR(
99
+ graph():
100
+ %0 : float = prim::Constant[value=51.2]()
101
+ %1 : None = prim::Constant()
102
+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
103
+ return (%2))IR" ;
104
+
105
+ auto g = std::make_shared<torch::jit::Graph>();
106
+ torch::jit::parseIR (graph, &*g);
107
+
108
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
109
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
110
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
111
+ }
112
+
113
+ TEST (Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
114
+ const auto graph = R"IR(
115
+ graph():
116
+ %0 : int = prim::Constant[value=1]()
117
+ %1 : int = prim::Constant[value=51]()
118
+ %2 : None = prim::Constant()
119
+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
120
+ return (%3))IR" ;
121
+
122
+ auto g = std::make_shared<torch::jit::Graph>();
123
+ torch::jit::parseIR (graph, &*g);
124
+
125
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
126
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
127
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
128
+ }
129
+
130
+ TEST (Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
131
+ const auto graph = R"IR(
132
+ graph():
133
+ %0 : float = prim::Constant[value=1.5]()
134
+ %1 : float = prim::Constant[value=51.2]()
135
+ %2 : None = prim::Constant()
136
+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
137
+ return (%3))IR" ;
138
+
139
+ auto g = std::make_shared<torch::jit::Graph>();
140
+ torch::jit::parseIR (graph, &*g);
141
+
142
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
143
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
144
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
145
+ }
146
+
147
+ TEST (Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
148
+ const auto graph = R"IR(
149
+ graph():
150
+ %0 : int = prim::Constant[value=1]()
151
+ %1 : int = prim::Constant[value=51]()
152
+ %2 : int = prim::Constant[value=1]()
153
+ %3 : None = prim::Constant()
154
+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
155
+ return (%4))IR" ;
156
+
157
+ auto g = std::make_shared<torch::jit::Graph>();
158
+ torch::jit::parseIR (graph, &*g);
159
+
160
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
161
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
162
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
163
+ }
164
+
165
+ TEST (Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
166
+ const auto graph = R"IR(
167
+ graph():
168
+ %0 : float = prim::Constant[value=1.2]()
169
+ %1 : float = prim::Constant[value=51.6]()
170
+ %2 : float = prim::Constant[value=1.5]()
171
+ %3 : None = prim::Constant()
172
+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
173
+ return (%4))IR" ;
174
+
175
+ auto g = std::make_shared<torch::jit::Graph>();
176
+ torch::jit::parseIR (graph, &*g);
177
+
178
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
179
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
180
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
181
+ }
182
+
183
+ TEST (Evaluators, ATenSizeNegativeConvertsCorrectly) {
81
184
const auto graph = R"IR(
82
185
graph(%0 : Tensor):
83
186
%1 : int = prim::Constant[value=-1]()
0 commit comments