Skip to content

Commit d291131

Browse files
authored
Merge pull request #318 from NVIDIA/true_divide_lowering
True divide lowering
2 parents 7e467a6 + 7ffe6b6 commit d291131

File tree

7 files changed

+70
-2
lines changed

7 files changed

+70
-2
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
340340
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
341341
return true;
342342
}})
343-
.pattern({"aten::div_.Scalar(Tensor self, Scalar other) -> (Tensor)",
343+
.pattern({"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
344344
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
345345
auto self = args[0].ITensorOrFreeze(ctx);
346346
auto otherScalar = args[1].unwrapToScalar().to<float>();

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4848
// passes::UnpackBatchNorm(g);
4949
passes::UnpackLogSoftmax(g);
5050
passes::RemoveNOPs(g);
51+
passes::AliasOperators(g);
5152
torch::jit::EliminateDeadCode(g);
5253
LOG_GRAPH(*g);
5354
}

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ cc_library(
2525
"unpack_addmm.cpp",
2626
"unpack_batch_norm.cpp",
2727
"unpack_log_softmax.cpp",
28+
"op_aliasing.cpp"
2829
],
2930
deps = [
3031
"//core/util:prelude",

core/lowering/passes/op_aliasing.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string true_divide_pattern = R"IR(
12+
graph(%s, %o):
13+
%1 : Tensor = aten::true_divide(%s, %o)
14+
return (%1))IR";
15+
std::string div_pattern = R"IR(
16+
graph(%s, %o):
17+
%1 : Tensor = aten::div(%s, %o)
18+
return (%1))IR";
19+
;
20+
21+
// TODO
22+
// complete other element wise pass
23+
24+
torch::jit::SubgraphRewriter true_divide_to_div;
25+
true_divide_to_div.RegisterRewritePattern(true_divide_pattern, div_pattern);
26+
true_divide_to_div.runOnGraph(graph);
27+
LOG_GRAPH("Post map true_divide -> div: " << *graph);
28+
}
29+
30+
} // namespace passes
31+
} // namespace lowering
32+
} // namespace core
33+
} // namespace trtorch

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
1919
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2020
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
22+
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
2223

2324
} // namespace passes
2425
} // namespace lowering

tests/core/lowering/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@ lowering_test(
1919
name = "test_remove_detach_pass",
2020
)
2121

22+
lowering_test(
23+
name = "test_operator_aliasing_pass",
24+
)
25+
2226
test_suite(
2327
name = "lowering_tests",
2428
tests = [
2529
":test_remove_contiguous_pass",
2630
":test_remove_to",
27-
":test_remove_detach_pass"
31+
":test_remove_detach_pass",
32+
":test_operator_aliasing_pass"
2833
]
2934
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, LoweringTrueDivideCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%s, %o):
12+
%2 = aten::true_divide(%s, %o)
13+
return (%2))IR";
14+
std::string target_graph = R"IR(
15+
graph(%s, %o):
16+
%2 = aten::div(%s, %o)
17+
return (%2))IR";
18+
19+
auto sg = std::make_shared<torch::jit::Graph>();
20+
torch::jit::parseIR(source_graph, &*sg);
21+
trtorch::core::lowering::passes::ElementWisePass(sg);
22+
23+
auto tg = std::make_shared<torch::jit::Graph>();
24+
torch::jit::parseIR(target_graph, &*tg);
25+
26+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
27+
}

0 commit comments

Comments
 (0)