File tree Expand file tree Collapse file tree 7 files changed +70
-2
lines changed
conversion/converters/impl Expand file tree Collapse file tree 7 files changed +70
-2
lines changed Original file line number Diff line number Diff line change @@ -340,7 +340,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
340
340
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
341
341
return true ;
342
342
}})
343
- .pattern({" aten::div_.Scalar(Tensor self, Scalar other) -> ( Tensor)" ,
343
+ .pattern({" aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a! )" ,
344
344
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
345
345
auto self = args[0 ].ITensorOrFreeze (ctx);
346
346
auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
Original file line number Diff line number Diff line change @@ -48,6 +48,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
48
48
// passes::UnpackBatchNorm(g);
49
49
passes::UnpackLogSoftmax (g);
50
50
passes::RemoveNOPs (g);
51
+ passes::AliasOperators (g);
51
52
torch::jit::EliminateDeadCode (g);
52
53
LOG_GRAPH (*g);
53
54
}
Original file line number Diff line number Diff line change @@ -25,6 +25,7 @@ cc_library(
25
25
"unpack_addmm.cpp" ,
26
26
"unpack_batch_norm.cpp" ,
27
27
"unpack_log_softmax.cpp" ,
28
+ "op_aliasing.cpp"
28
29
],
29
30
deps = [
30
31
"//core/util:prelude" ,
Original file line number Diff line number Diff line change
1
+ #include < torch/csrc/jit/passes/subgraph_rewrite.h>
2
+
3
+ #include " core/util/prelude.h"
4
+
5
+ namespace trtorch {
6
+ namespace core {
7
+ namespace lowering {
8
+ namespace passes {
9
+
10
+ void AliasOperators (std::shared_ptr<torch::jit::Graph>& graph) {
11
+ std::string true_divide_pattern = R"IR(
12
+ graph(%s, %o):
13
+ %1 : Tensor = aten::true_divide(%s, %o)
14
+ return (%1))IR" ;
15
+ std::string div_pattern = R"IR(
16
+ graph(%s, %o):
17
+ %1 : Tensor = aten::div(%s, %o)
18
+ return (%1))IR" ;
19
+ ;
20
+
21
+ // TODO
22
+ // complete other element wise pass
23
+
24
+ torch::jit::SubgraphRewriter true_divide_to_div;
25
+ true_divide_to_div.RegisterRewritePattern (true_divide_pattern, div_pattern);
26
+ true_divide_to_div.runOnGraph (graph);
27
+ LOG_GRAPH (" Post map true_divide -> div: " << *graph);
28
+ }
29
+
30
+ } // namespace passes
31
+ } // namespace lowering
32
+ } // namespace core
33
+ } // namespace trtorch
Original file line number Diff line number Diff line change @@ -19,6 +19,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
19
19
void UnpackAddMM (std::shared_ptr<torch::jit::Graph>& graph);
20
20
void UnpackBatchNorm (std::shared_ptr<torch::jit::Graph>& graph);
21
21
void UnpackLogSoftmax (std::shared_ptr<torch::jit::Graph>& graph);
22
+ void AliasOperators (std::shared_ptr<torch::jit::Graph>& graph);
22
23
23
24
} // namespace passes
24
25
} // namespace lowering
Original file line number Diff line number Diff line change @@ -19,11 +19,16 @@ lowering_test(
19
19
name = "test_remove_detach_pass" ,
20
20
)
21
21
22
+ lowering_test (
23
+ name = "test_operator_aliasing_pass" ,
24
+ )
25
+
22
26
test_suite (
23
27
name = "lowering_tests" ,
24
28
tests = [
25
29
":test_remove_contiguous_pass" ,
26
30
":test_remove_to" ,
27
- ":test_remove_detach_pass"
31
+ ":test_remove_detach_pass" ,
32
+ ":test_operator_aliasing_pass"
28
33
]
29
34
)
Original file line number Diff line number Diff line change
1
+ #include < string>
2
+ #include " core/compiler.h"
3
+ #include " core/lowering/passes/passes.h"
4
+ #include " gtest/gtest.h"
5
+ #include " tests/util/util.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+
9
+ TEST (LoweringPasses, LoweringTrueDivideCorrectly) {
10
+ std::string source_graph = R"IR(
11
+ graph(%s, %o):
12
+ %2 = aten::true_divide(%s, %o)
13
+ return (%2))IR" ;
14
+ std::string target_graph = R"IR(
15
+ graph(%s, %o):
16
+ %2 = aten::div(%s, %o)
17
+ return (%2))IR" ;
18
+
19
+ auto sg = std::make_shared<torch::jit::Graph>();
20
+ torch::jit::parseIR (source_graph, &*sg);
21
+ trtorch::core::lowering::passes::ElementWisePass (sg);
22
+
23
+ auto tg = std::make_shared<torch::jit::Graph>();
24
+ torch::jit::parseIR (target_graph, &*tg);
25
+
26
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
27
+ }
You can’t perform that action at this time.
0 commit comments