Skip to content

Commit 68dd005

Browse files
authored
Merge pull request #827 from NVIDIA/peri044/dtype_layout
feat: Implement aten::to.dtype_layout pass
2 parents 3b1ce7c + 543e2a5 commit 68dd005

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

core/lowering/passes/reduce_to.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
1616
graph(%x, %device, %dtype, %nb, %copy, %format):
1717
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
1818
return (%out))IR";
19+
std::string to_dtype_layout_pattern = R"IR(
20+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
21+
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
22+
return (%out))IR";
23+
24+
std::string to_dtype_multi_input_pattern = R"IR(
25+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
26+
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
27+
return (%out))IR";
1928

2029
std::string to_type_as_pattern = R"IR(
2130
graph(%input, %other):
@@ -34,6 +43,11 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
3443
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
3544
map_aten_device_to_dtype.runOnGraph(graph);
3645

46+
// replace aten::to.dtype_layout with aten::to.dtype
47+
torch::jit::SubgraphRewriter map_aten_dtype_layout;
48+
map_aten_dtype_layout.RegisterRewritePattern(to_dtype_layout_pattern, to_dtype_multi_input_pattern);
49+
map_aten_dtype_layout.runOnGraph(graph);
50+
3751
// replace aten::type_as with aten::to.other
3852
torch::jit::SubgraphRewriter map_aten_type_as_to_other;
3953
map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);

tests/core/lowering/test_reduce_to_pass.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ TEST(LoweringPasses, ReduceToCorrectly) {
2828
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
2929
}
3030

31+
TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
32+
std::string source_graph = R"IR(
33+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
34+
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
35+
return (%out))IR";
36+
std::string target_graph = R"IR(
37+
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
38+
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
39+
return (%out))IR";
40+
41+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
42+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
43+
auto sg = std::make_shared<torch::jit::Graph>();
44+
torch::jit::parseIR(source_graph, &*sg);
45+
torch_tensorrt::core::lowering::passes::ReduceToOperation(sg);
46+
47+
auto tg = std::make_shared<torch::jit::Graph>();
48+
torch::jit::parseIR(target_graph, &*tg);
49+
50+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
51+
}
52+
3153
TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
3254
std::string source_graph = R"IR(
3355
graph(%input, %other):

0 commit comments

Comments
 (0)