Skip to content

Commit dd34ec1

Browse files
authored
fix: fix the inappropriate lowering pass of aten::to (#1649)
1 parent 1c02c20 commit dd34ec1

File tree

2 files changed

+0
-37
lines changed

2 files changed

+0
-37
lines changed

core/lowering/passes/reduce_to.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@ namespace lowering {
88
namespace passes {
99

1010
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
11-
std::string to_dtype_layout_pattern = R"IR(
12-
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
13-
%out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format)
14-
return (%out))IR";
15-
16-
std::string to_dtype_multi_input_pattern = R"IR(
17-
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
18-
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
19-
return (%out))IR";
20-
2111
std::string to_type_as_pattern = R"IR(
2212
graph(%input, %other):
2313
%out : Tensor = aten::type_as(%input, %other)
@@ -30,11 +20,6 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
3020
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
3121
return (%out))IR";
3222

33-
// replace aten::to.dtype_layout with aten::to.dtype
34-
torch::jit::SubgraphRewriter map_aten_dtype_layout;
35-
map_aten_dtype_layout.RegisterRewritePattern(to_dtype_layout_pattern, to_dtype_multi_input_pattern);
36-
map_aten_dtype_layout.runOnGraph(graph);
37-
3823
// replace aten::type_as with aten::to.other
3924
torch::jit::SubgraphRewriter map_aten_type_as_to_other;
4025
map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);

tests/core/lowering/test_reduce_to_pass.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,6 @@
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/ir/subgraph_matcher.h"
88

9-
TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
10-
std::string source_graph = R"IR(
11-
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
12-
%out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format)
13-
return (%out))IR";
14-
std::string target_graph = R"IR(
15-
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
16-
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
17-
return (%out))IR";
18-
19-
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
20-
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
21-
auto sg = std::make_shared<torch::jit::Graph>();
22-
torch::jit::parseIR(source_graph, &*sg);
23-
torch_tensorrt::core::lowering::passes::ReduceToOperation(sg);
24-
25-
auto tg = std::make_shared<torch::jit::Graph>();
26-
torch::jit::parseIR(target_graph, &*tg);
27-
28-
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
29-
}
30-
319
TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
3210
std::string source_graph = R"IR(
3311
graph(%input, %other):

0 commit comments

Comments
 (0)