@@ -71,47 +71,29 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
71
71
72
72
bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod,
73
73
std::string method_name) {
74
- auto g = mod.get_method (method_name).graph ();
75
- // Go through PyTorch Lowering to simplify graph and extract weight parameters
76
- auto graph_and_parameters = torch::jit::LowerGraph (*g, mod._ivalue ());
74
+ // Go through Lowering to simplify graph and extract weight parameters
75
+ auto graph_and_parameters = lowering::Lower (mod, method_name);
77
76
78
- g = graph_and_parameters.first ;
79
-
80
- // Go through TRTorch Lowering to reformat graph to be conversion friendly
81
- // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
82
- lowering::LowerGraph (g);
83
-
84
- auto params = graph_and_parameters.second ;
85
- auto named_params = conversion::get_named_params (g->inputs (), params);
77
+ auto g = graph_and_parameters.first ;
86
78
LOG_DEBUG (*g << " (CheckMethodOperatorSupport)\n " );
87
79
88
- // Is this necessary?
89
- lowering::LowerBlock (g->block ());
90
-
91
80
return conversion::VerifyConverterSupportForBlock (g->block ());
92
81
}
93
82
94
83
std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod,
95
84
std::string method_name,
96
85
ExtraInfo cfg) {
97
- auto convert_cfg = std::move (cfg.convert_info );
98
-
99
- auto g = mod.get_method (method_name).graph ();
100
- // Go through PyTorch Lowering to simplify graph and extract weight parameters
101
- auto graph_and_parameters = torch::jit::LowerGraph (*g, mod._ivalue ());
102
-
103
- g = graph_and_parameters.first ;
104
86
105
- // Go through TRTorch Lowering to reformat graph to be conversion friendly
106
- // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
107
- lowering::LowerGraph (g);
87
+ // Go through Lowering to simplify graph and extract weight parameters
88
+ auto graph_and_parameters = lowering::Lower (mod, method_name);
108
89
90
+ auto convert_cfg = std::move (cfg.convert_info );
91
+ auto g = graph_and_parameters.first ;
109
92
auto params = graph_and_parameters.second ;
110
93
auto named_params = conversion::get_named_params (g->inputs (), params);
94
+
111
95
LOG_INFO (*g << " (CompileGraph)\n " );
112
96
113
- // Is this necessary?
114
- lowering::LowerBlock (g->block ());
115
97
auto engine = ConvertBlockToEngine (g->block (), convert_cfg, named_params);
116
98
return std::move (engine);
117
99
}
0 commit comments