@@ -3,6 +3,8 @@
#include "core/compiler.h"
#include "core/util/trt_util.h"
#include "gtest/gtest.h"
+ #include "torch/csrc/jit/ir/constants.h"
+ #include "torch/csrc/jit/ir/irparser.h"
#include "torch/script.h"

bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
@@ -22,39 +24,117 @@ bool checkAllInputsExistInStitchedGraph(std::shared_ptr<torch::jit::Graph> g) {
  return true;
}

- TEST(Partitioning, StitchResNet50SegmentedBlockCorrectly) {
-   torch::jit::script::Module mod;
-   try {
-     mod = torch::jit::load("tests/modules/resnet50_traced.jit.pt");
-   } catch (const c10::Error& e) {
-     std::cerr << "error loading the model\n";
-     return;
+ TEST(Partitioning, StitchSequentialModelSegmentedBlockCorrectly) {
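+   // Simple sequential model (conv -> relu -> conv -> log_sigmoid -> conv) with the weights and biases passed in as graph inputs.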
+   const auto graph = R"IR(
+       graph(%0 : Tensor,
+             %w1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+             %b1 : Float(32),
+             %w2 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+             %b2 : Float(16),
+             %w3 : Float(8, 16, 3, 3, strides=[144, 9, 3, 1]),
+             %b3 : Float(8)):
+         %2 : int[] = prim::Constant[value=[1, 1]]()
+         %3 : int = prim::Constant[value=1]()
+         %10 : bool = prim::Constant[value=0]()
+         %11 : int[] = prim::Constant[value=[0, 0]]()
+         %12 : Tensor = aten::_convolution(%0, %w1, %b1, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+         %13 : Tensor = aten::relu(%12)
+         %14 : Tensor = aten::_convolution(%13, %w2, %b2, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+         %15 : Tensor = aten::log_sigmoid(%14)
+         %16 : Tensor = aten::_convolution(%15, %w3, %b3, %2, %2, %2, %10, %11, %3, %10, %10, %10, %10)
+         return (%16))IR";
+
+   auto parsed_g = std::make_shared<torch::jit::Graph>();
+   torch::jit::parseIR(graph, parsed_g.get());
+
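+   // Freeze the weight/bias graph inputs as constant tensors, then clone the remaining nodes into a fresh graph g.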
+   auto g = std::make_shared<torch::jit::Graph>();
+   std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}, {8, 16, 3, 3}, {8}};
+   std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+   for (size_t i = 0; i < all_shapes.size(); ++i) {
+     auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+     torch::jit::IValue cur_val = in.clone();
+     auto new_val = g->insertConstant(cur_val);
+     tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
+   }
+   for (auto node : parsed_g->nodes()) {
+     if (node->kind() == torch::jit::prim::Constant)
+       continue;
+     trtorch::core::util::cloneNode(node, g, tensor_to_constant);
  }
+   g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);

-   std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
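+   // Only the activation input %0 gets an InputRange; the weights were frozen as constants above.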
+   std::vector<trtorch::core::ir::InputRange> input_ranges;
+   input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
  trtorch::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
-   cfg.partition_info.forced_fallback_operators.push_back("aten::add");
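+   // Wrap the hand-built graph in a Module with a forward() method so CompileGraph can compile it.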
+   torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+   auto self = g->insertInput(0, "self_1");
+   self->setType(mod.type());
+   auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+   auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+   mod.type()->addMethod(cur_method);
+   cur_method->setSchema(schema);
+
  torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-   auto g = new_mod.get_method("forward").graph();
-   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
+   auto fallback_g = new_mod.get_method("forward").graph();
+   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
}

- TEST(Partitioning, StitchMobileNetSegmentedBlockCorrectlyEdge) {
-   torch::jit::script::Module mod;
-   try {
-     mod = torch::jit::load("tests/modules/mobilenet_v2_traced.jit.pt");
-   } catch (const c10::Error& e) {
-     std::cerr << "error loading the model\n";
-     return;
+ TEST(Partitioning, StitchBranchModelSegmentedBlockCorrectly) {
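+   // Branched model: two convolution paths split off a shared conv (one through log_sigmoid), then joined by add and max_pool2d.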
+   const auto graph = R"IR(
+       graph(%0 : Tensor,
+             %1 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1]),
+             %2 : Float(32),
+             %3 : Float(16, 32, 3, 3, strides=[288, 9, 3, 1]),
+             %4 : Float(16)):
+         %5 : int[] = prim::Constant[value=[0, 0]]()
+         %6 : int[] = prim::Constant[value=[2, 2]]()
+         %7 : bool = prim::Constant[value=0]()
+         %8 : int[] = prim::Constant[value=[1, 1]]()
+         %9 : int = prim::Constant[value=1]()
+         %10 : Tensor = aten::_convolution(%0, %1, %2, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+         %11 : Tensor = aten::_convolution(%10, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+         %12 : Tensor = aten::log_sigmoid(%10)
+         %13 : Tensor = aten::_convolution(%12, %3, %4, %8, %8, %8, %7, %5, %9, %7, %7, %7, %7)
+         %14 : Tensor = aten::relu(%11)
+         %15 : Tensor = aten::add(%13, %14, %9)
+         %16 : Tensor = aten::max_pool2d(%15, %6, %6, %5, %8, %7)
+         return (%16))IR";
+
+   auto parsed_g = std::make_shared<torch::jit::Graph>();
+   torch::jit::parseIR(graph, parsed_g.get());
+
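+   // Same constant-freezing setup as above: replace the weight/bias inputs with constants and clone the rest into g.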
+   auto g = std::make_shared<torch::jit::Graph>();
+   std::vector<std::vector<int64_t>> all_shapes{{32, 3, 3, 3}, {32}, {16, 32, 3, 3}, {16}};
+   std::unordered_map<torch::jit::Value*, torch::jit::Value*> tensor_to_constant;
+   for (size_t i = 0; i < all_shapes.size(); ++i) {
+     auto in = at::randint(5, all_shapes[i], {at::kCUDA});
+     torch::jit::IValue cur_val = in.clone();
+     auto new_val = g->insertConstant(cur_val);
+     tensor_to_constant[parsed_g->inputs()[i + 1]] = new_val;
  }
+   for (auto node : parsed_g->nodes()) {
+     if (node->kind() == torch::jit::prim::Constant)
+       continue;
+     trtorch::core::util::cloneNode(node, g, tensor_to_constant);
+   }
+   g->registerOutput(tensor_to_constant[parsed_g->outputs()[0]]);

-   std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({1, 3, 224, 224})};
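+   // As above, only %0 needs an InputRange since the weights are constants.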
+   std::vector<trtorch::core::ir::InputRange> input_ranges;
+   input_ranges.push_back(trtorch::core::ir::InputRange({3, 3, 16, 16}));
  trtorch::core::CompileSpec cfg(input_ranges);
  cfg.partition_info.enabled = true;
-   cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh");
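+   // Same Module wrapping as in the sequential test so CompileGraph can compile the hand-built graph.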
+   torch::jit::script::Module mod(c10::QualifiedName("module"));
+
+   auto self = g->insertInput(0, "self_1");
+   self->setType(mod.type());
+   auto cur_method = mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), g);
+   auto schema = trtorch::core::util::GenerateGraphSchema(cur_method->name(), g);
+   mod.type()->addMethod(cur_method);
+   cur_method->setSchema(schema);

  torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
-   auto g = new_mod.get_method("forward").graph();
-   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(g));
+   auto fallback_g = new_mod.get_method("forward").graph();
+   ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g));
}