Skip to content

Commit 36c7d96

Browse files
committed
support erf/asinh/acosh/atanh
Signed-off-by: inocsin <[email protected]>
1 parent 2b50334 commit 36c7d96

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

core/conversion/converters/impl/unary.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ convert(ceil, kCEIL);
4040
convert(sqrt, kSQRT);
4141
convert(exp, kEXP);
4242
convert(neg, kNEG);
43+
convert(erf, kERF);
44+
convert(asinh, kASINH);
45+
convert(acosh, kACOSH);
46+
convert(atanh, kATANH);
4347

4448
#undef convert
4549

tests/core/conversion/converters/test_unary.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ std::string gen_test_graph(const std::string& unary) {
2121
auto g = std::make_shared<torch::jit::Graph>(); \
2222
torch::jit::parseIR(graph, &*g); \
2323
\
24-
auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \
24+
float offset = 0; \
25+
if (#name == "Acosh") offset += 1; /*input larger than 1 for acosh*/ \
26+
auto in = at::empty({10}, {at::kCUDA}).uniform_(0+offset, 0.5+offset); \
2527
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \
2628
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \
2729
\
@@ -48,5 +50,9 @@ test_unary(ceil, Ceil);
4850
test_unary(sqrt, Sqrt);
4951
test_unary(exp, Exp);
5052
test_unary(neg, Neg);
53+
test_unary(erf, Erf);
54+
test_unary(asinh, Asinh);
55+
test_unary(acosh, Acosh);
56+
test_unary(atanh, Atanh);
5157

5258
#undef test_unary

0 commit comments

Comments
 (0)