Skip to content

Commit d1270d9

Browse files
committed
refactor(//tests): Relax threshold for module tests
Perhaps revisit in the future to see if we can reduce this back Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 7622a97 commit d1270d9

13 files changed

+122
-93
lines changed

tests/core/converters/test_conv.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7-
// aten::_convolution(Tensor input, Tensor weight,
7+
// aten::_convolution(Tensor input, Tensor weight,
88
// Tensor? bias, int[] stride, int[] padding,
9-
// int[] dilation, bool transposed,
10-
// int[] output_padding, int groups, bool benchmark,
9+
// int[] dilation, bool transposed,
10+
// int[] output_padding, int groups, bool benchmark,
1111
// bool deterministic, bool cudnn_enabled) -> (Tensor)
1212

1313
void conv_test_helper(std::string graph_ir) {
1414
auto g = std::make_shared<torch::jit::Graph>();
1515
torch::jit::script::parseIR(graph_ir, &*g);
16-
16+
1717
auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});
1818
auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA});
1919
auto b = at::randint(1, 10, {8}, {at::kCUDA});
2020

2121
auto jit_in = at::clone(in);
2222
auto jit_w = at::clone(w);
2323
auto jit_b = at::clone(b);
24-
24+
2525
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
2626
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
2727

@@ -31,11 +31,11 @@ void conv_test_helper(std::string graph_ir) {
3131
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
3232
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
3333

34-
auto trt = trt_results[0].reshape(jit_results[0].sizes());
34+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
3535

36-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt));
36+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
3737
}
38-
38+
3939
TEST(Converters, ATenConvolutionConvertsCorrectly) {
4040
const auto graph = R"IR(
4141
graph(%0 : Tensor,
@@ -45,7 +45,7 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
4545
%4 : int = prim::Constant[value=0]()
4646
%5 : int = prim::Constant[value=1]()
4747
%6 : int = prim::Constant[value=0]()
48-
%7 : bool = prim::Constant[value=0]()
48+
%7 : bool = prim::Constant[value=0]()
4949
%8 : int[] = prim::ListConstruct(%3, %3)
5050
%9 : int[] = prim::ListConstruct(%4, %4)
5151
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -55,15 +55,15 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
5555

5656
auto g = std::make_shared<torch::jit::Graph>();
5757
torch::jit::script::parseIR(graph, &*g);
58-
58+
5959
auto in = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});
6060
auto w = at::randint(1, 10, {8, 3, 5, 5}, {at::kCUDA});
6161
auto b = at::randint(1, 10, {8}, {at::kCUDA});
6262

6363
auto jit_in = at::clone(in);
6464
auto jit_w = at::clone(w);
6565
auto jit_b = at::clone(b);
66-
66+
6767
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
6868
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
6969

@@ -73,9 +73,9 @@ TEST(Converters, ATenConvolutionConvertsCorrectly) {
7373
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
7474
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
7575

76-
auto trt = trt_results[0].reshape(jit_results[0].sizes());
76+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
7777

78-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt));
78+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
7979
}
8080

8181
TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
@@ -87,7 +87,7 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
8787
%4 : int = prim::Constant[value=0]()
8888
%5 : int = prim::Constant[value=1]()
8989
%6 : int = prim::Constant[value=0]()
90-
%7 : bool = prim::Constant[value=0]()
90+
%7 : bool = prim::Constant[value=0]()
9191
%8 : int[] = prim::ListConstruct(%3, %3)
9292
%9 : int[] = prim::ListConstruct(%4, %4)
9393
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -97,12 +97,12 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
9797

9898
auto g = std::make_shared<torch::jit::Graph>();
9999
torch::jit::script::parseIR(graph, &*g);
100-
100+
101101
auto in = at::randint(1, 2, {1, 1, 3, 3}, {at::kCUDA});
102102
auto w = at::randint(1, 2, {4, 1, 2, 2}, {at::kCUDA});
103103

104104
auto jit_in = at::clone(in);
105-
auto jit_w = at::clone(w);
105+
auto jit_w = at::clone(w);
106106
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w});
107107
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
108108

@@ -111,9 +111,9 @@ TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
111111
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w});
112112
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
113113

114-
auto trt = trt_results[0].reshape(jit_results[0].sizes());
114+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
115115

116-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt));
116+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
117117
}
118118

119119

@@ -126,7 +126,7 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
126126
%4 : int = prim::Constant[value=0]()
127127
%5 : int = prim::Constant[value=1]()
128128
%6 : int = prim::Constant[value=0]()
129-
%7 : bool = prim::Constant[value=0]()
129+
%7 : bool = prim::Constant[value=0]()
130130
%8 : int[] = prim::ListConstruct(%3, %3)
131131
%9 : int[] = prim::ListConstruct(%4, %4)
132132
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -137,15 +137,15 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
137137

138138
auto g = std::make_shared<torch::jit::Graph>();
139139
torch::jit::script::parseIR(graph, &*g);
140-
140+
141141
auto in = at::randint(1, 10, {1, 3, 9, 9}, {at::kCUDA});
142142
auto w = at::randint(1, 10, {4, 3, 3, 3}, {at::kCUDA});
143143
auto b = at::randint(1, 10, {4}, {at::kCUDA});
144144

145145
auto jit_in = at::clone(in);
146146
auto jit_w = at::clone(w);
147147
auto jit_b = at::clone(b);
148-
148+
149149
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
150150
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
151151

@@ -155,9 +155,9 @@ TEST(Converters, ATenConvolutionWithStrideConvertsCorrectly) {
155155
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
156156
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
157157

158-
auto trt = trt_results[0].reshape(jit_results[0].sizes());
158+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
159159

160-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt));
160+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
161161
}
162162

163163
TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
@@ -169,7 +169,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
169169
%4 : int = prim::Constant[value=2]()
170170
%5 : int = prim::Constant[value=1]()
171171
%6 : int = prim::Constant[value=0]()
172-
%7 : bool = prim::Constant[value=0]()
172+
%7 : bool = prim::Constant[value=0]()
173173
%8 : int[] = prim::ListConstruct(%3, %3)
174174
%9 : int[] = prim::ListConstruct(%4, %4)
175175
%10 : int[] = prim::ListConstruct(%5, %5)
@@ -180,15 +180,15 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
180180

181181
auto g = std::make_shared<torch::jit::Graph>();
182182
torch::jit::script::parseIR(graph, &*g);
183-
183+
184184
auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});
185185
auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
186186
auto b = at::randint(1, 10, {4}, {at::kCUDA});
187187

188188
auto jit_in = at::clone(in);
189189
auto jit_w = at::clone(w);
190190
auto jit_b = at::clone(b);
191-
191+
192192
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
193193
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
194194

@@ -198,9 +198,9 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
198198
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
199199
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
200200

201-
auto trt = trt_results[0].reshape(jit_results[0].sizes());
201+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
202202

203-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt));
203+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
204204
}
205205

206206
// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
@@ -212,7 +212,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
212212
// %4 : int = prim::Constant[value=0]()
213213
// %5 : int = prim::Constant[value=2]()
214214
// %6 : int = prim::Constant[value=0]()
215-
// %7 : bool = prim::Constant[value=0]()
215+
// %7 : bool = prim::Constant[value=0]()
216216
// %8 : int[] = prim::ListConstruct(%3, %3)
217217
// %9 : int[] = prim::ListConstruct(%4, %4)
218218
// %10 : int[] = prim::ListConstruct(%5, %5)
@@ -233,7 +233,7 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
233233
// %4 : int = prim::Constant[value=0]()
234234
// %5 : int = prim::Constant[value=1]()
235235
// %6 : int = prim::Constant[value=2]()
236-
// %7 : bool = prim::Constant[value=0]()
236+
// %7 : bool = prim::Constant[value=0]()
237237
// %8 : int[] = prim::ListConstruct(%3, %3)
238238
// %9 : int[] = prim::ListConstruct(%4, %4)
239239
// %10 : int[] = prim::ListConstruct(%5, %5)
@@ -254,14 +254,14 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
254254
// %4 : int = prim::Constant[value=0]()
255255
// %5 : int = prim::Constant[value=1]()
256256
// %6 : int = prim::Constant[value=0]()
257-
// %7 : bool = prim::Constant[value=0]()
257+
// %7 : bool = prim::Constant[value=0]()
258258
// %8 : int[] = prim::ListConstruct(%3, %3)
259259
// %9 : int[] = prim::ListConstruct(%4, %4)
260260
// %10 : int[] = prim::ListConstruct(%5, %5)
261261
// %11 : int[] = prim::ListConstruct(%6, %6)
262262
// %12 : int = prim::Constant[value=2]()
263263
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
264264
// return (%13))IR";
265-
265+
266266
// conv_test_helper(graph);
267267
// }

tests/core/converters/test_element_wise.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
void pointwise_test_helper(std::string graph_ir) {
88
auto g = std::make_shared<torch::jit::Graph>();
99
torch::jit::script::parseIR(graph_ir, &*g);
10-
10+
1111
auto in0 = at::randint(1, 5, {5}, {at::kCUDA});
1212
auto in1 = at::randint(1, 5, {5}, {at::kCUDA});
1313
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
@@ -18,7 +18,7 @@ void pointwise_test_helper(std::string graph_ir) {
1818
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
1919
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1});
2020

21-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
21+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
2222
}
2323

2424

tests/core/converters/test_linear.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7+
#define LAYER_TEST //used to set threshold for diff
78

89
TEST(Converters, ATenLinearNoBiasConvertsCorrectly) {
910
const auto graph = R"IR(
@@ -28,7 +29,7 @@ TEST(Converters, ATenLinearNoBiasConvertsCorrectly) {
2829
params = trtorch::core::conversion::get_named_params(g->inputs(), {w});
2930
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
3031

31-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0])));
32+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
3233
}
3334

3435

@@ -62,5 +63,5 @@ TEST(Converters, ATenLinearBiasConvertsCorrectly) {
6263
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
6364

6465

65-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0])));
66+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
6667
}

tests/core/converters/test_pooling.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,34 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7+
#define LAYER_TEST //used to set threshold for diff
8+
79
TEST(Converters, ATenMaxPool2DConvertsCorrectly) {
810
const auto graph = R"IR(
911
graph(%0 : Tensor):
1012
%1 : int = prim::Constant[value=0]()
1113
%2 : int = prim::Constant[value=1]()
1214
%3 : int = prim::Constant[value=2]()
13-
%5 : bool = prim::Constant[value=0]()
15+
%5 : bool = prim::Constant[value=0]()
1416
%6 : int[] = prim::ListConstruct(%1, %1)
1517
%7 : int[] = prim::ListConstruct(%2, %2)
1618
%8 : int[] = prim::ListConstruct(%3, %3)
17-
%10 : Tensor = aten::max_pool2d(%0, %8, %7, %6, %7, %5)
19+
%10 : Tensor = aten::max_pool2d(%0, %8, %7, %6, %7, %5)
1820
return (%10))IR";
1921

2022
auto g = std::make_shared<torch::jit::Graph>();
2123
torch::jit::script::parseIR(graph, &*g);
2224

23-
//PyTorch MaxPool needs a 3D input
25+
//PyTorch MaxPool needs a 3D input
2426
auto in = at::randint(-5, 5, {1, 4, 4}, at::kCUDA);
2527
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
2628
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
2729

2830
in = at::clone(in);
2931
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
3032
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
31-
32-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
33+
34+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
3335
}
3436

3537
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
@@ -38,13 +40,13 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
3840
%2 : int = prim::Constant[value=3]()
3941
%3 : int = prim::Constant[value=4]()
4042
%6 : int[] = prim::ListConstruct(%2, %3)
41-
%10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
43+
%10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
4244
return (%10))IR";
4345

4446
auto g = std::make_shared<torch::jit::Graph>();
4547
torch::jit::script::parseIR(graph, &*g);
4648

47-
//PyTorch MaxPool needs a 3D input
49+
//PyTorch MaxPool needs a 3D input
4850
auto in = at::randint(-5, 5, {1, 12, 16}, at::kCUDA);
4951

5052
auto jit_in = at::clone(in);
@@ -54,6 +56,6 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
5456
auto trt_in = at::clone(in);
5557
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
5658
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
57-
58-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
59+
60+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
5961
}

tests/core/converters/test_reduce.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7+
#define LAYER_TEST //used to set threshold for diff
8+
79
TEST(Converters, ATenMeanConvertsCorrectly) {
810
const auto graph = R"IR(
911
graph(%0 : Tensor):
@@ -22,7 +24,7 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
2224
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
2325
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
2426

25-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
27+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
2628
}
2729

2830
TEST(Converters, ATenMeanHigherDimensionConvertsCorrectly) {
@@ -43,7 +45,7 @@ TEST(Converters, ATenMeanHigherDimensionConvertsCorrectly) {
4345
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
4446
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
4547

46-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
48+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
4749
}
4850

4951
TEST(Converters, ATenMeanRowConvertsCorrectly) {
@@ -67,7 +69,7 @@ TEST(Converters, ATenMeanRowConvertsCorrectly) {
6769
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
6870
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
6971

70-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
72+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
7173
}
7274

7375
TEST(Converters, ATenMeanMultiDimsConvertsCorrectly) {
@@ -92,7 +94,7 @@ TEST(Converters, ATenMeanMultiDimsConvertsCorrectly) {
9294
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
9395
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
9496

95-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
97+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
9698
}
9799

98100
TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
@@ -116,5 +118,5 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
116118
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
117119
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
118120

119-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
121+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
120122
}

0 commit comments

Comments
 (0)