@@ -3,6 +3,7 @@
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/runtime/jit_exception.h"
 #include "torch/torch.h"
 
 TEST(Evaluators, DivIntEvaluatesCorrectly) {
@@ -613,4 +614,55 @@ TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
 
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
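+  // "res5_1" != "res5_2", so aten::eq is false, prim::If takes block1, and
+  // prim::RaiseException throws the message built by aten::format.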
+  const auto graph = R"IR(
+      graph(%x_1 : Tensor, %x_2 : Tensor):
+        %0 : int = prim::Constant[value=1]()
+        %1 : str = prim::Constant[value="res5_1"]()
+        %2 : str = prim::Constant[value="{} is not equal to {}"]()
+        %3 : str = prim::Constant[value="res5_2"]()
+        %5713 : Tensor = prim::Uninitialized()
+        %4 : str = aten::format(%2, %1, %3)
+        %5 : bool = aten::eq(%1, %3)
+        %y : Tensor = prim::If(%5)
+          block0():
+            %194 : Tensor = aten::add(%x_1, %x_2, %0)
+            -> (%194)
+          block1():
+            prim::RaiseException(%4)
+            -> (%5713)
+        return (%y))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = in0.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  std::vector<at::Tensor> jit_results, trt_results;
+  std::string error_jit, error_torch_trt;
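+  // Run the graph in the TorchScript interpreter and capture the raised exception.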
+  try {
+    jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+  } catch (const torch::jit::JITException& error) {
+    error_jit = error.what();
+  }
+
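+  // Run the same graph through the TensorRT engine path and capture its error.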
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  try {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+  } catch (const torch_tensorrt::Error& error) {
+    error_torch_trt = error.what();
+  }
+
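+  // Strip the runtime-specific prefixes ("RuntimeError:" is 13 characters,
+  // "Error from TorchScript:" is 23) and compare what remains.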
+  auto position1 = error_jit.find("RuntimeError:");
+  auto position2 = error_torch_trt.find("Error from TorchScript:");
+  std::string jit_msg = error_jit.substr(position1 + 13);
+  std::string torch_trt_msg = error_torch_trt.substr(position2 + 23);
+  ASSERT_EQ(jit_msg, torch_trt_msg);
 }