1+ #include < string>
2+ #include " core/compiler.h"
3+ #include " core/lowering/passes/passes.h"
4+ #include " gtest/gtest.h"
5+ #include " tests/util/util.h"
6+ #include " torch/csrc/jit/ir/irparser.h"
7+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+ TEST (LoweringPasses, Conv1dCorrectly) {
10+ const auto source_graph = R"IR(
11+ graph(%0 : Tensor,
12+ %1 : Float(4, 3, 3, strides=[9, 3, 1]),
13+ %2 : Float(3)):
14+ %4 : int = prim::Constant[value=0]()
15+ %5 : int = prim::Constant[value=1]()
16+ %6 : int = prim::Constant[value=1]()
17+ %stride : int[] = prim::ListConstruct(%6)
18+ %padding : int[] = prim::ListConstruct(%4)
19+ %dilation : int[] = prim::ListConstruct(%5)
20+ %12 : Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
21+ return (%12))IR" ;
22+
23+ const auto target_graph = R"IR(
24+ graph(%0 : Tensor,
25+ %1 : Float(4, 3, 3, strides=[9, 3, 1]),
26+ %2 : Float(3)):
27+ %3 : bool = prim::Constant[value=0]()
28+ %4 : int = prim::Constant[value=0]()
29+ %5 : int = prim::Constant[value=1]()
30+ %6 : int = prim::Constant[value=1]()
31+ %stride : int[] = prim::ListConstruct(%6)
32+ %padding : int[] = prim::ListConstruct(%4)
33+ %dilation : int[] = prim::ListConstruct(%5)
34+ %output_padding : int[] = prim::Constant[value=[0]]()
35+ %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
36+ return (%12))IR" ;
37+
38+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
39+ auto sg = std::make_shared<torch::jit::Graph>();
40+ torch::jit::parseIR (source_graph, &*sg);
41+ trtorch::core::lowering::passes::Conv1DToConvolution (sg);
42+
43+ auto tg = std::make_shared<torch::jit::Graph>();
44+ torch::jit::parseIR (target_graph, &*tg);
45+
46+ auto in = at::randint (1 , 2 , {1 , 3 , 3 }, {at::kCUDA });
47+ auto w = at::randint (1 , 2 , {4 , 3 , 3 }, {at::kCUDA });
48+ auto b = at::randint (1 , 10 , {4 }, {at::kCUDA });
49+
50+ auto trt_in = at::clone (in);
51+ auto trt_w = at::clone (w);
52+ auto trt_b = at::clone (b);
53+ auto params = trtorch::core::conversion::get_named_params (sg->inputs (), {trt_w, trt_b});
54+ auto trt_results_sg = trtorch::tests::util::RunGraphEngine (sg, params, {trt_in});
55+
56+ params = trtorch::core::conversion::get_named_params (tg->inputs (), {trt_w, trt_b});
57+ auto trt_results_tg = trtorch::tests::util::RunGraphEngine (tg, params, {trt_in});
58+
59+ ASSERT_TRUE (trtorch::tests::util::almostEqual (trt_results_sg[0 ], trt_results_tg[0 ], 2e-6 ));
60+ }
61+
62+ TEST (LoweringPasses, ConvTransposed1dCorrectly) {
63+ const auto source_graph = R"IR(
64+ graph(%0 : Tensor,
65+ %1 : Float(8, 3, 3, strides=[9, 3, 1]),
66+ %2 : Float(3)):
67+ %3 : int = prim::Constant[value=1]()
68+ %4 : int = prim::Constant[value=0]()
69+ %5 : int = prim::Constant[value=1]()
70+ %6 : int = prim::Constant[value=0]()
71+ %stride : int[] = prim::ListConstruct(%3)
72+ %padding : int[] = prim::ListConstruct(%4)
73+ %dilation : int[] = prim::ListConstruct(%5)
74+ %output_padding : int[] = prim::ListConstruct(%6)
75+ %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %stride, %padding, %output_padding, %3, %dilation)
76+ return (%12))IR" ;
77+
78+ const auto target_graph = R"IR(
79+ graph(%0 : Tensor,
80+ %1 : Float(8, 3, 3, strides=[9, 3, 1]),
81+ %2 : Float(3)):
82+ %3 : int = prim::Constant[value=1]()
83+ %4 : int = prim::Constant[value=0]()
84+ %5 : int = prim::Constant[value=1]()
85+ %6 : int = prim::Constant[value=0]()
86+ %7 : bool = prim::Constant[value=0]()
87+ %8 : bool = prim::Constant[value=1]()
88+ %stride : int[] = prim::ListConstruct(%3)
89+ %padding : int[] = prim::ListConstruct(%4)
90+ %dilation : int[] = prim::ListConstruct(%5)
91+ %output_padding : int[] = prim::ListConstruct(%6)
92+ %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
93+ return (%12))IR" ;
94+
95+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
96+ auto sg = std::make_shared<torch::jit::Graph>();
97+ torch::jit::parseIR (source_graph, &*sg);
98+ trtorch::core::lowering::passes::ConvTransposed1DToConvolution (sg);
99+
100+ auto tg = std::make_shared<torch::jit::Graph>();
101+ torch::jit::parseIR (target_graph, &*tg);
102+
103+ auto in = at::randint (1 , 2 , {1 , 8 , 3 }, {at::kCUDA });
104+ auto w = at::randint (1 , 2 , {8 , 3 , 3 }, {at::kCUDA });
105+ auto b = at::randint (1 , 10 , {3 }, {at::kCUDA });
106+
107+ auto trt_in = at::clone (in);
108+ auto trt_w = at::clone (w);
109+ auto trt_b = at::clone (b);
110+ auto params = trtorch::core::conversion::get_named_params (sg->inputs (), {trt_w, trt_b});
111+ auto trt_results_sg = trtorch::tests::util::RunGraphEngine (sg, params, {trt_in});
112+
113+ params = trtorch::core::conversion::get_named_params (tg->inputs (), {trt_w, trt_b});
114+ auto trt_results_tg = trtorch::tests::util::RunGraphEngine (tg, params, {trt_in});
115+
116+ ASSERT_TRUE (trtorch::tests::util::almostEqual (trt_results_sg[0 ], trt_results_tg[0 ], 2e-6 ));
117+ }
0 commit comments