@@ -42,7 +42,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
42
42
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
43
43
auto trt_out = trt_mod.forward (inputs_);
44
44
45
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
45
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
46
46
}
47
47
48
48
TEST (CppAPITests, TestCollectionTupleInput) {
@@ -85,7 +85,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
85
85
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
86
86
auto trt_out = trt_mod.forward (complex_inputs);
87
87
88
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
88
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
89
89
}
90
90
91
91
TEST (CppAPITests, TestCollectionListInput) {
@@ -144,7 +144,7 @@ TEST(CppAPITests, TestCollectionListInput) {
144
144
LOG_DEBUG (" Finish compile" );
145
145
auto trt_out = trt_mod.forward (complex_inputs);
146
146
147
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
147
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
148
148
}
149
149
150
150
TEST (CppAPITests, TestCollectionTupleInputOutput) {
@@ -178,23 +178,18 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
178
178
torch::jit::IValue complex_input_shape (input_shape_tuple);
179
179
std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
180
180
torch::jit::IValue complex_input_shape2 (input_tuple2);
181
- // torch::jit::IValue complex_input_shape(list);
182
181
183
182
auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
184
183
compile_settings.min_block_size = 1 ;
185
184
186
- // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
187
-
188
185
// // FP16 execution
189
186
compile_settings.enabled_precisions = {torch::kHalf };
190
187
// // Compile module
191
188
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
192
189
auto trt_out = trt_mod.forward (complex_inputs);
193
190
194
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
195
- out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-4 ));
196
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
197
- out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-4 ));
191
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor ()));
192
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor ()));
198
193
}
199
194
200
195
TEST (CppAPITests, TestCollectionListInputOutput) {
@@ -252,10 +247,8 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
252
247
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
253
248
auto trt_out = trt_mod.forward (complex_inputs);
254
249
255
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
256
- out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor (), 1e-5 ));
257
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
258
- out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor (), 1e-5 ));
250
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor ()));
251
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor ()));
259
252
}
260
253
261
254
TEST (CppAPITests, TestCollectionComplexModel) {
@@ -313,8 +306,6 @@ TEST(CppAPITests, TestCollectionComplexModel) {
313
306
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
314
307
auto trt_out = trt_mod.forward (complex_inputs);
315
308
316
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
317
- out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-5 ));
318
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
319
- out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-5 ));
309
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor ()));
310
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor ()));
320
311
}
0 commit comments