1
+ #include < iostream>
2
+ #include < string>
3
+ #include " core/compiler.h"
4
+ #include " gtest/gtest.h"
5
+ #include " tests/util/util.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+
8
+ TEST (Converters, ATenReplication_pad1dTensorConvertsCorrectly) {
9
+ const auto graph = R"IR(
10
+ graph(%0 : Tensor):
11
+ %1 : int[] = prim::Constant[value=[2, 3]]()
12
+ %2 : Tensor = aten::replication_pad1d(%0, %1)
13
+ return (%2))IR" ;
14
+
15
+ auto g = std::make_shared<torch::jit::Graph>();
16
+ torch::jit::parseIR (graph, g.get ());
17
+
18
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 }, {at::kCUDA });
19
+
20
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
21
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
22
+
23
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
24
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
25
+
26
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
27
+ }
28
+
29
+ TEST (Converters, ATenReplication_pad1dRightZeroTensorConvertsCorrectly) {
30
+ const auto graph = R"IR(
31
+ graph(%0 : Tensor):
32
+ %1 : int[] = prim::Constant[value=[2, 0]]()
33
+ %2 : Tensor = aten::replication_pad1d(%0, %1)
34
+ return (%2))IR" ;
35
+
36
+ auto g = std::make_shared<torch::jit::Graph>();
37
+ torch::jit::parseIR (graph, g.get ());
38
+
39
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 }, {at::kCUDA });
40
+
41
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
42
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
43
+
44
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
45
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
46
+
47
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
48
+ }
49
+
50
+ TEST (Converters, ATenReplication_pad1dTensorConvertsCorrectlyWithDynamic) {
51
+ const auto graph = R"IR(
52
+ graph(%0 : Tensor):
53
+ %1 : int[] = prim::Constant[value=[2, 3]]()
54
+ %2 : Tensor = aten::replication_pad1d(%0, %1)
55
+ return (%2))IR" ;
56
+
57
+ auto g = std::make_shared<torch::jit::Graph>();
58
+ torch::jit::parseIR (graph, g.get ());
59
+
60
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 }, {at::kCUDA });
61
+
62
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
63
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
64
+
65
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
66
+ auto trt_results = trtorch::tests::util::RunGraphEngineDynamic (g, params, {in1});
67
+
68
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
69
+ }
70
+
71
+ TEST (Converters, ATenReplication_pad2dTensorConvertsCorrectly) {
72
+ const auto graph = R"IR(
73
+ graph(%0 : Tensor):
74
+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
75
+ %2 : Tensor = aten::replication_pad2d(%0, %1)
76
+ return (%2))IR" ;
77
+
78
+ auto g = std::make_shared<torch::jit::Graph>();
79
+ torch::jit::parseIR (graph, g.get ());
80
+
81
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 }, {at::kCUDA });
82
+
83
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
84
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
85
+
86
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
87
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
88
+
89
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
90
+ }
91
+
92
+ TEST (Converters, ATenReplication_pad2dRightBottomZeroTensorConvertsCorrectly) {
93
+ const auto graph = R"IR(
94
+ graph(%0 : Tensor):
95
+ %1 : int[] = prim::Constant[value=[2, 0, 2, 0]]()
96
+ %2 : Tensor = aten::replication_pad2d(%0, %1)
97
+ return (%2))IR" ;
98
+
99
+ auto g = std::make_shared<torch::jit::Graph>();
100
+ torch::jit::parseIR (graph, g.get ());
101
+
102
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 }, {at::kCUDA });
103
+
104
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
105
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
106
+
107
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
108
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
109
+
110
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
111
+ }
112
+
113
+ TEST (Converters, ATenReplication_pad2dTensorConvertsCorrectlyWithDynamic) {
114
+ const auto graph = R"IR(
115
+ graph(%0 : Tensor):
116
+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
117
+ %2 : Tensor = aten::replication_pad2d(%0, %1)
118
+ return (%2))IR" ;
119
+
120
+ auto g = std::make_shared<torch::jit::Graph>();
121
+ torch::jit::parseIR (graph, g.get ());
122
+
123
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 }, {at::kCUDA });
124
+
125
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
126
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
127
+
128
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
129
+ auto trt_results = trtorch::tests::util::RunGraphEngineDynamic (g, params, {in1});
130
+
131
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
132
+ }
133
+
134
+ TEST (Converters, ATenReplication_pad3dTensorConvertsCorrectly) {
135
+ const auto graph = R"IR(
136
+ graph(%0 : Tensor):
137
+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
138
+ %2 : Tensor = aten::replication_pad3d(%0, %1)
139
+ return (%2))IR" ;
140
+
141
+ auto g = std::make_shared<torch::jit::Graph>();
142
+ torch::jit::parseIR (graph, g.get ());
143
+
144
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 , 3 }, {at::kCUDA });
145
+
146
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
147
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
148
+
149
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
150
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
151
+
152
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
153
+ }
154
+
155
+ TEST (Converters, ATenReplication_pad3dRightBottomBackZeroTensorConvertsCorrectly) {
156
+ const auto graph = R"IR(
157
+ graph(%0 : Tensor):
158
+ %1 : int[] = prim::Constant[value=[2, 0, 2, 0, 1, 0]]()
159
+ %2 : Tensor = aten::replication_pad3d(%0, %1)
160
+ return (%2))IR" ;
161
+
162
+ auto g = std::make_shared<torch::jit::Graph>();
163
+ torch::jit::parseIR (graph, g.get ());
164
+
165
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 , 3 }, {at::kCUDA });
166
+
167
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
168
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
169
+
170
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
171
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
172
+
173
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
174
+ }
175
+
176
+ TEST (Converters, ATenReplication_pad3dTensorConvertsCorrectlyWithDynamic) {
177
+ const auto graph = R"IR(
178
+ graph(%0 : Tensor):
179
+ %1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
180
+ %2 : Tensor = aten::replication_pad3d(%0, %1)
181
+ return (%2))IR" ;
182
+
183
+ auto g = std::make_shared<torch::jit::Graph>();
184
+ torch::jit::parseIR (graph, g.get ());
185
+
186
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 4 , 5 , 3 }, {at::kCUDA });
187
+
188
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
189
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
190
+
191
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
192
+ auto trt_results = trtorch::tests::util::RunGraphEngineDynamic (g, params, {in1});
193
+
194
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
195
+ }
0 commit comments