Skip to content

Commit 7ec5c73

Browse files
committed
chore: add uTest about aten::format
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 3a33d33 commit 7ec5c73

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gtest/gtest.h"
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/csrc/jit/runtime/jit_exception.h"
67
#include "torch/torch.h"
78

89
TEST(Evaluators, DivIntEvaluatesCorrectly) {
@@ -613,4 +614,55 @@ TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
613614

614615
ASSERT_TRUE(
615616
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
617+
}
618+
619+
TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
620+
const auto graph = R"IR(
621+
graph(%x_1 : Tensor, %x_2 : Tensor):
622+
%0 : int = prim::Constant[value=1]()
623+
%1 : str = prim::Constant[value="res5_1"]()
624+
%2 : str = prim::Constant[value="{} is not equal to {}"]()
625+
%3 : str = prim::Constant[value="res5_2"]()
626+
%5713 : Tensor = prim::Uninitialized()
627+
%4 : str = aten::format(%2, %1, %3)
628+
%5 : bool = aten::eq(%1, %3)
629+
%y : Tensor = prim::If(%5)
630+
block0():
631+
%194 : Tensor = aten::add(%x_1, %x_2, %0)
632+
-> (%194)
633+
block1():
634+
prim::RaiseException(%4)
635+
-> (%5713)
636+
return (%y))IR";
637+
auto g = std::make_shared<torch::jit::Graph>();
638+
torch::jit::parseIR(graph, &*g);
639+
640+
auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
641+
auto in1 = in0.clone();
642+
643+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
644+
std::vector<at::Tensor> jit_results, trt_results;
645+
std::string error_jit, error_torch_trt;
646+
try {
647+
jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
648+
} catch (const torch::jit::JITException& error) {
649+
error_jit = error.what();
650+
}
651+
652+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
653+
try {
654+
trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
655+
} catch (const torch_tensorrt::Error& error) {
656+
error_torch_trt = error.what();
657+
}
658+
659+
auto position1 = error_jit.find("RuntimeError:");
660+
auto position2 = error_torch_trt.find("Error from TorchScript:");
661+
std::string jit_msg = error_jit.substr(position1 + 13);
662+
std::string torch_trt_msg = error_torch_trt.substr(position2 + 23);
663+
if (jit_msg == torch_trt_msg) {
664+
ASSERT_TRUE(true);
665+
} else {
666+
ASSERT_TRUE(false);
667+
}
616668
}

0 commit comments

Comments
 (0)