Skip to content

Commit 054373b

Browse files
committed
Merge branch 'guoruoqian-support_eq_str' into 'release/1.0'
feat: support aten::eq.str evaluator See merge request adlsa/TRTorch!9
2 parents ff10ff4 + 6d73e43 commit 054373b

File tree

5 files changed

+89
-1
lines changed

5 files changed

+89
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
3333
"aten::eq.bool(bool a, bool b) -> (bool)",
3434
"aten::eq.int(int a, int b) -> (bool)",
3535
"aten::eq.float(float a, float b) -> (bool)",
36+
"aten::eq.str(str a, str b) -> (bool)",
3637
"aten::eq.int_float(int a, float b) -> (bool)",
3738
"aten::eq.float_int(float a, int b) -> (bool)",
3839
}));
@@ -97,7 +98,12 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
9798
"aten::ge.float_int(float a, int b) -> (bool)",
9899
}));
99100

100-
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, {"aten::__and__(int a, int b) -> (bool)"});
101+
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
102+
and,
103+
"aten::__and__",
104+
a&& b,
105+
bool,
106+
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
101107
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
102108
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
103109
xor,

core/conversion/evaluators/eval_macros.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@
5757
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
5858
return {}; \
5959
} \
60+
} else if (args.at(n->input(0)).IValue()->isString()) { \
61+
auto a = args.at(n->input(0)).unwrapToString(); \
62+
if (args.at(n->input(1)).IValue()->isString()) { \
63+
auto b = args.at(n->input(1)).unwrapToString(); \
64+
return operation; \
65+
} else { \
66+
TRTORCH_THROW_ERROR( \
67+
"Unimplemented data type for " \
68+
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
69+
return {}; \
70+
} \
6071
} else { \
6172
TORCHTRT_THROW_ERROR( \
6273
"Unimplemented data type for " \

core/conversion/var/Var.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class Var : torch::CustomClassHolder {
3636
double unwrapToDouble();
3737
bool unwrapToBool(bool default_val);
3838
bool unwrapToBool();
39+
std::string unwrapToString(std::string default_val);
40+
std::string unwrapToString();
3941
c10::Scalar unwrapToScalar(c10::Scalar default_val);
4042
c10::Scalar unwrapToScalar();
4143
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);

core/conversion/var/Var_inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ DEFINE_UNWRAP_TO(at::Tensor, Tensor)
3838
DEFINE_UNWRAP_TO(int64_t, Int)
3939
DEFINE_UNWRAP_TO(double, Double)
4040
DEFINE_UNWRAP_TO(bool, Bool)
41+
DEFINE_UNWRAP_TO(std::string, String)
4142
DEFINE_UNWRAP_TO(c10::Scalar, Scalar)
4243
DEFINE_UNWRAP_TO(c10::List<int64_t>, IntList)
4344
DEFINE_UNWRAP_TO(c10::List<double>, DoubleList)

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,5 +510,73 @@ TEST(Evaluators, ATenIsFloatingPointEvaluatesFalseCorrectly) {
510510
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
511511
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in_trt});
512512

513+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
514+
}
515+
516+
TEST(Evaluators, EqStrResultIsTrueEvaluatesCorrectly) {
517+
const auto graph = R"IR(
518+
graph():
519+
%1 : str = prim::Constant[value="res3"]()
520+
%2 : str = prim::Constant[value="res3"]()
521+
%3 : bool = aten::eq(%1, %2)
522+
return (%3))IR";
523+
524+
auto g = std::make_shared<torch::jit::Graph>();
525+
torch::jit::parseIR(graph, g.get());
526+
527+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
528+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
529+
530+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
531+
}
532+
533+
TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
534+
const auto graph = R"IR(
535+
graph():
536+
%1 : str = prim::Constant[value="res3"]()
537+
%2 : str = prim::Constant[value="res4"]()
538+
%3 : bool = aten::eq(%1, %2)
539+
return (%3))IR";
540+
541+
auto g = std::make_shared<torch::jit::Graph>();
542+
torch::jit::parseIR(graph, g.get());
543+
544+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
545+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
546+
547+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
548+
}
549+
550+
TEST(Evaluators, AndBoolResultIsTrueEvaluatesCorrectly) {
551+
const auto graph = R"IR(
552+
graph():
553+
%1 : bool = prim::Constant[value=1]()
554+
%2 : bool = prim::Constant[value=1]()
555+
%3 : bool = aten::__and__(%1, %2)
556+
return (%3))IR";
557+
558+
auto g = std::make_shared<torch::jit::Graph>();
559+
torch::jit::parseIR(graph, g.get());
560+
561+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
562+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
563+
564+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
565+
}
566+
567+
TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
568+
const auto graph = R"IR(
569+
graph():
570+
%1 : bool = prim::Constant[value=1]()
571+
%2 : bool = prim::Constant[value=0]()
572+
%3 : bool = aten::__and__(%1, %2)
573+
return (%3))IR";
574+
575+
auto g = std::make_shared<torch::jit::Graph>();
576+
torch::jit::parseIR(graph, g.get());
577+
578+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
579+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
580+
513581
ASSERT_TRUE(jit_results[0] == trt_results[0]);
514582
}

0 commit comments

Comments
 (0)