4
4
#include " tests/util/util.h"
5
5
#include " core/compiler.h"
6
6
7
- // aten::_convolution(Tensor input, Tensor weight,
7
+ // aten::_convolution(Tensor input, Tensor weight,
8
8
// Tensor? bias, int[] stride, int[] padding,
9
- // int[] dilation, bool transposed,
10
- // int[] output_padding, int groups, bool benchmark,
9
+ // int[] dilation, bool transposed,
10
+ // int[] output_padding, int groups, bool benchmark,
11
11
// bool deterministic, bool cudnn_enabled) -> (Tensor)
12
12
13
13
void conv_test_helper (std::string graph_ir) {
14
14
auto g = std::make_shared<torch::jit::Graph>();
15
15
torch::jit::script::parseIR (graph_ir, &*g);
16
-
16
+
17
17
auto in = at::randint (1 , 10 , {1 , 3 , 10 , 10 }, {at::kCUDA });
18
18
auto w = at::randint (1 , 10 , {8 , 3 , 5 , 5 }, {at::kCUDA });
19
19
auto b = at::randint (1 , 10 , {8 }, {at::kCUDA });
20
20
21
21
auto jit_in = at::clone (in);
22
22
auto jit_w = at::clone (w);
23
23
auto jit_b = at::clone (b);
24
-
24
+
25
25
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
26
26
auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
27
27
@@ -31,11 +31,11 @@ void conv_test_helper(std::string graph_ir) {
31
31
params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
32
32
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
33
33
34
- auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
34
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
35
35
36
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt));
36
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
37
37
}
38
-
38
+
39
39
TEST (Converters, ATenConvolutionConvertsCorrectly) {
40
40
const auto graph = R"IR(
41
41
graph(%0 : Tensor,
@@ -45,7 +45,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
45
45
%4 : int = prim::Constant[value=0]()
46
46
%5 : int = prim::Constant[value=1]()
47
47
%6 : int = prim::Constant[value=0]()
48
- %7 : bool = prim::Constant[value=0]()
48
+ %7 : bool = prim::Constant[value=0]()
49
49
%8 : int[] = prim::ListConstruct(%3, %3)
50
50
%9 : int[] = prim::ListConstruct(%4, %4)
51
51
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -55,15 +55,15 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
55
55
56
56
auto g = std::make_shared<torch::jit::Graph>();
57
57
torch::jit::script::parseIR (graph, &*g);
58
-
58
+
59
59
auto in = at::randint (1 , 10 , {1 , 3 , 10 , 10 }, {at::kCUDA });
60
60
auto w = at::randint (1 , 10 , {8 , 3 , 5 , 5 }, {at::kCUDA });
61
61
auto b = at::randint (1 , 10 , {8 }, {at::kCUDA });
62
62
63
63
auto jit_in = at::clone (in);
64
64
auto jit_w = at::clone (w);
65
65
auto jit_b = at::clone (b);
66
-
66
+
67
67
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
68
68
auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
69
69
@@ -73,9 +73,9 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
73
73
params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
74
74
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
75
75
76
- auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
76
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
77
77
78
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt));
78
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
79
79
}
80
80
81
81
TEST (Converters, ATenConvolutionNoBiasConvertsCorrectly) {
@@ -87,7 +87,7 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
87
87
%4 : int = prim::Constant[value=0]()
88
88
%5 : int = prim::Constant[value=1]()
89
89
%6 : int = prim::Constant[value=0]()
90
- %7 : bool = prim::Constant[value=0]()
90
+ %7 : bool = prim::Constant[value=0]()
91
91
%8 : int[] = prim::ListConstruct(%3, %3)
92
92
%9 : int[] = prim::ListConstruct(%4, %4)
93
93
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -97,12 +97,12 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
97
97
98
98
auto g = std::make_shared<torch::jit::Graph>();
99
99
torch::jit::script::parseIR (graph, &*g);
100
-
100
+
101
101
auto in = at::randint (1 , 2 , {1 , 1 , 3 , 3 }, {at::kCUDA });
102
102
auto w = at::randint (1 , 2 , {4 , 1 , 2 , 2 }, {at::kCUDA });
103
103
104
104
auto jit_in = at::clone (in);
105
- auto jit_w = at::clone (w);
105
+ auto jit_w = at::clone (w);
106
106
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w});
107
107
auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
108
108
@@ -111,9 +111,9 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
111
111
params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w});
112
112
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
113
113
114
- auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
114
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
115
115
116
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt));
116
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
117
117
}
118
118
119
119
@@ -126,7 +126,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
126
126
%4 : int = prim::Constant[value=0]()
127
127
%5 : int = prim::Constant[value=1]()
128
128
%6 : int = prim::Constant[value=0]()
129
- %7 : bool = prim::Constant[value=0]()
129
+ %7 : bool = prim::Constant[value=0]()
130
130
%8 : int[] = prim::ListConstruct(%3, %3)
131
131
%9 : int[] = prim::ListConstruct(%4, %4)
132
132
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -137,15 +137,15 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
137
137
138
138
auto g = std::make_shared<torch::jit::Graph>();
139
139
torch::jit::script::parseIR (graph, &*g);
140
-
140
+
141
141
auto in = at::randint (1 , 10 , {1 , 3 , 9 , 9 }, {at::kCUDA });
142
142
auto w = at::randint (1 , 10 , {4 , 3 , 3 , 3 }, {at::kCUDA });
143
143
auto b = at::randint (1 , 10 , {4 }, {at::kCUDA });
144
144
145
145
auto jit_in = at::clone (in);
146
146
auto jit_w = at::clone (w);
147
147
auto jit_b = at::clone (b);
148
-
148
+
149
149
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
150
150
auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
151
151
@@ -155,9 +155,9 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
155
155
params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
156
156
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
157
157
158
- auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
158
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
159
159
160
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt));
160
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
161
161
}
162
162
163
163
TEST (Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
@@ -169,7 +169,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
169
169
%4 : int = prim::Constant[value=2]()
170
170
%5 : int = prim::Constant[value=1]()
171
171
%6 : int = prim::Constant[value=0]()
172
- %7 : bool = prim::Constant[value=0]()
172
+ %7 : bool = prim::Constant[value=0]()
173
173
%8 : int[] = prim::ListConstruct(%3, %3)
174
174
%9 : int[] = prim::ListConstruct(%4, %4)
175
175
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -180,15 +180,15 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
180
180
181
181
auto g = std::make_shared<torch::jit::Graph>();
182
182
torch::jit::script::parseIR (graph, &*g);
183
-
183
+
184
184
auto in = at::randint (1 , 10 , {1 , 3 , 4 , 4 }, {at::kCUDA });
185
185
auto w = at::randint (1 , 10 , {4 , 3 , 2 , 2 }, {at::kCUDA });
186
186
auto b = at::randint (1 , 10 , {4 }, {at::kCUDA });
187
187
188
188
auto jit_in = at::clone (in);
189
189
auto jit_w = at::clone (w);
190
190
auto jit_b = at::clone (b);
191
-
191
+
192
192
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
193
193
auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
194
194
@@ -198,9 +198,9 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
198
198
params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
199
199
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
200
200
201
- auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
201
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
202
202
203
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt));
203
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
204
204
}
205
205
206
206
// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
@@ -212,7 +212,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
212
212
// %4 : int = prim::Constant[value=0]()
213
213
// %5 : int = prim::Constant[value=2]()
214
214
// %6 : int = prim::Constant[value=0]()
215
- // %7 : bool = prim::Constant[value=0]()
215
+ // %7 : bool = prim::Constant[value=0]()
216
216
// %8 : int[] = prim::ListConstruct(%3, %3)
217
217
// %9 : int[] = prim::ListConstruct(%4, %4)
218
218
// %10 : int[] = prim::ListConstruct(%5, %5)
@@ -233,7 +233,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
233
233
// %4 : int = prim::Constant[value=0]()
234
234
// %5 : int = prim::Constant[value=1]()
235
235
// %6 : int = prim::Constant[value=2]()
236
- // %7 : bool = prim::Constant[value=0]()
236
+ // %7 : bool = prim::Constant[value=0]()
237
237
// %8 : int[] = prim::ListConstruct(%3, %3)
238
238
// %9 : int[] = prim::ListConstruct(%4, %4)
239
239
// %10 : int[] = prim::ListConstruct(%5, %5)
@@ -254,14 +254,14 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
254
254
// %4 : int = prim::Constant[value=0]()
255
255
// %5 : int = prim::Constant[value=1]()
256
256
// %6 : int = prim::Constant[value=0]()
257
- // %7 : bool = prim::Constant[value=0]()
257
+ // %7 : bool = prim::Constant[value=0]()
258
258
// %8 : int[] = prim::ListConstruct(%3, %3)
259
259
// %9 : int[] = prim::ListConstruct(%4, %4)
260
260
// %10 : int[] = prim::ListConstruct(%5, %5)
261
261
// %11 : int[] = prim::ListConstruct(%6, %6)
262
262
// %12 : int = prim::Constant[value=2]()
263
263
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
264
264
// return (%13))IR";
265
-
265
+
266
266
// conv_test_helper(graph);
267
267
// }
0 commit comments