Skip to content

Commit 6a12934

Browse files
authored
Merge pull request #901 from Njuapp/aten_remainder
Add remainders' op
2 parents 759664d + bbd074c commit 6a12934

File tree

7 files changed

+163
-0
lines changed

7 files changed

+163
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4444
passes::EliminateExceptionOrPassPattern(g);
4545
passes::ReduceToOperation(g);
4646
passes::ReduceGelu(g);
47+
passes::ReduceRemainder(g);
4748
passes::RemoveContiguous(g);
4849
passes::ViewToReshape(g);
4950
passes::RemoveDropout(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"op_aliasing.cpp",
1919
"reduce_to.cpp",
2020
"reduce_gelu.cpp",
21+
"reduce_remainder.cpp",
2122
"remove_bn_dim_check.cpp",
2223
"remove_contiguous.cpp",
2324
"view_to_reshape.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2121
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
2222
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
2323
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
24+
void ReduceRemainder(std::shared_ptr<torch::jit::Graph>& graph);
2425
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
2526
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
2627
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
#include "core/util/prelude.h"
3+
4+
namespace torch_tensorrt {
5+
namespace core {
6+
namespace lowering {
7+
namespace passes {
8+
9+
void ReduceRemainder(std::shared_ptr<torch::jit::Graph>& graph) {
10+
std::string remainder_pattern = R"IR(
11+
graph(%self : Tensor, %other : Tensor):
12+
%out : Tensor = aten::remainder(%self, %other)
13+
return (%out))IR";
14+
15+
std::string remainder_reduce_pattern = R"IR(
16+
graph(%self : Tensor, %other : Tensor):
17+
%alpha : int = prim::Constant[value=1]()
18+
%floor: Tensor = aten::floor_divide(%self, %other)
19+
%prod: Tensor = aten::mul(%floor, %other)
20+
%out: Tensor = aten::sub(%self, %prod, %alpha)
21+
return (%out))IR";
22+
23+
std::string remainder_scalar_pattern = R"IR(
24+
graph(%self : Tensor, %other : Scalar):
25+
%out : Tensor = aten::remainder(%self, %other)
26+
return (%out))IR";
27+
28+
std::string remainder_scalar_reduce_pattern = R"IR(
29+
graph(%self : Tensor, %other : Scalar):
30+
%alpha : int = prim::Constant[value=1]()
31+
%floor: Tensor = aten::floor_divide(%self, %other)
32+
%prod: Tensor = aten::mul(%floor, %other)
33+
%out: Tensor = aten::sub(%self, %prod, %alpha)
34+
return (%out))IR";
35+
36+
// replace aten::remainder with pointwise operations
37+
torch::jit::SubgraphRewriter map_remainder_to_pointwise_ops;
38+
map_remainder_to_pointwise_ops.RegisterRewritePattern(remainder_pattern, remainder_reduce_pattern);
39+
map_remainder_to_pointwise_ops.RegisterRewritePattern(remainder_scalar_pattern, remainder_scalar_reduce_pattern);
40+
map_remainder_to_pointwise_ops.runOnGraph(graph);
41+
42+
LOG_GRAPH("Post lowering of [aten::remainder] -> " << *graph);
43+
}
44+
45+
} // namespace passes
46+
} // namespace lowering
47+
} // namespace core
48+
} // namespace torch_tensorrt

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <string>
22
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
34
#include "gtest/gtest.h"
45
#include "tests/util/util.h"
56
#include "torch/csrc/jit/ir/irparser.h"
@@ -423,3 +424,52 @@ TEST(Converters, ATenLEScalarConvertsCorrectly) {
423424
return (%2))IR";
424425
pointwise_test_helper(graph, true, false, {5, 5});
425426
}
427+
428+
TEST(Converters, ATenRemainderConvertsCorrectly) {
429+
const auto graph = R"IR(
430+
graph(%0 : Tensor, %1 : Tensor):
431+
%2 : Tensor = aten::remainder(%0, %1)
432+
return (%2))IR";
433+
434+
auto g = std::make_shared<torch::jit::Graph>();
435+
torch::jit::parseIR(graph, &*g);
436+
437+
auto input1 = at::randint(-5, 5, {4, 5}, {at::kCUDA});
438+
auto input2 = at::randint(1, 5, {5}, {at::kCUDA});
439+
440+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
441+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input1, input2});
442+
443+
torch_tensorrt::core::lowering::passes::ReduceRemainder(g);
444+
445+
input1 = at::clone(input1);
446+
input2 = at::clone(input2);
447+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
448+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {input1, input2});
449+
450+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 5e-2));
451+
}
452+
453+
TEST(Converters, ATenRemainderWithScalarConvertsCorrectly) {
454+
const auto graph = R"IR(
455+
graph(%0 : Tensor):
456+
%scalar : float = prim::Constant[value=2.4]()
457+
%1 : Tensor = aten::remainder(%0, %scalar)
458+
return (%1))IR";
459+
460+
auto g = std::make_shared<torch::jit::Graph>();
461+
torch::jit::parseIR(graph, &*g);
462+
463+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
464+
465+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
466+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
467+
468+
torch_tensorrt::core::lowering::passes::ReduceRemainder(g);
469+
470+
in = at::clone(in);
471+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
472+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
473+
474+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
475+
}

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ lowering_test(
4646
name = "test_reduce_gelu",
4747
)
4848

49+
lowering_test(
50+
name = "test_reduce_remainder",
51+
)
52+
4953
lowering_test(
5054
name = "test_remove_detach_pass",
5155
)
@@ -82,6 +86,7 @@ test_suite(
8286
":test_view_to_reshape_pass",
8387
":test_remove_dropout_pass",
8488
":test_reduce_to_pass",
89+
":test_reduce_remainder",
8590
":test_reduce_gelu",
8691
":test_unpack_hardswish",
8792
":test_unpack_reduce_ops"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, ReduceRemainderCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%self : Tensor, %other : Tensor):
12+
%out : Tensor = aten::remainder(%self, %other)
13+
return (%out))IR";
14+
std::string target_graph = R"IR(
15+
graph(%self : Tensor, %other : Tensor):
16+
%alpha : int = prim::Constant[value=1]()
17+
%floor: Tensor = aten::floor_divide(%self, %other)
18+
%prod: Tensor = aten::mul(%floor, %other)
19+
%out: Tensor = aten::sub(%self, %prod, %alpha)
20+
return (%out))IR";
21+
22+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
23+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
24+
auto sg = std::make_shared<torch::jit::Graph>();
25+
torch::jit::parseIR(source_graph, &*sg);
26+
torch_tensorrt::core::lowering::passes::ReduceRemainder(sg);
27+
28+
auto tg = std::make_shared<torch::jit::Graph>();
29+
torch::jit::parseIR(target_graph, &*tg);
30+
31+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
32+
}
33+
34+
TEST(LoweringPasses, ReduceRemainderScalarCorrectly) {
35+
std::string source_graph = R"IR(
36+
graph(%self : Tensor, %other : Scalar):
37+
%out : Tensor = aten::remainder(%self, %other)
38+
return (%out))IR";
39+
std::string target_graph = R"IR(
40+
graph(%self : Tensor, %other : Scalar):
41+
%alpha : int = prim::Constant[value=1]()
42+
%floor: Tensor = aten::floor_divide(%self, %other)
43+
%prod: Tensor = aten::mul(%floor, %other)
44+
%out: Tensor = aten::sub(%self, %prod, %alpha)
45+
return (%out))IR";
46+
47+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
48+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
49+
auto sg = std::make_shared<torch::jit::Graph>();
50+
torch::jit::parseIR(source_graph, &*sg);
51+
torch_tensorrt::core::lowering::passes::ReduceRemainder(sg);
52+
53+
auto tg = std::make_shared<torch::jit::Graph>();
54+
torch::jit::parseIR(target_graph, &*tg);
55+
56+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
57+
}

0 commit comments

Comments
 (0)