21
21
#include < dlfcn.h>
22
22
#endif
23
23
24
- void load_library (std::string& plugin, std::string option, void * handle) {
24
+ void * load_library (std::string& custom_lib) {
25
+ void * handle = {nullptr };
25
26
#if defined(_WIN32)
26
- handle = LoadLibrary (plugin .c_str ());
27
+ handle = LoadLibrary (custom_lib .c_str ());
27
28
#else
28
- handle = dlopen (plugin .c_str (), RTLD_LAZY);
29
+ handle = dlopen (custom_lib .c_str (), RTLD_LAZY);
29
30
#endif
30
- if (handle == nullptr ) {
31
- torchtrt::logging::log (
32
- torchtrt::logging::Level::kERROR , std::string (" Could not load custom library " + plugin + " for " + option));
33
- } else {
34
- torchtrt::logging::log (
35
- torchtrt::logging::Level::kINFO , std::string (" Loaded custom library " + plugin + " for " + option));
36
- }
31
+ return handle;
37
32
}
38
33
39
- void unload_library (void * custom_lib, std::string& name) {
34
+ bool unload_library (void * custom_lib) {
35
+ bool success = false ;
40
36
#if defined(_WIN32)
41
- auto status = FreeLibrary (custom_lib);
42
- // Return status non-zero for success
43
- if (status) {
44
- torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + name));
45
- } else {
46
- torchtrt::logging::log (torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + name));
47
- }
37
+ // Returns status non-zero for success
38
+ success = FreeLibrary (custom_lib) ? true : false ;
48
39
#else
49
- auto status = dlclose (custom_lib);
50
- // Return status 0 for success
51
- if (!status) {
52
- torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + name));
53
- } else {
54
- torchtrt::logging::log (torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + name));
55
- }
40
+ success = dlclose (custom_lib) ? false : true ;
56
41
#endif
42
+ return success;
57
43
}
58
44
59
45
int main (int argc, char ** argv) {
@@ -188,10 +174,16 @@ int main(int argc, char** argv) {
188
174
" Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path" ,
189
175
{" save-engine" });
190
176
args::ValueFlagList<std::string> custom_torch_ops (
191
- parser, " custom-torch-ops" , " Shared object/DLL containing custom torch operator" , {" custom-torch-ops" });
177
+ parser,
178
+ " custom-torch-ops" ,
179
+ " (repeatable) Shared object/DLL containing custom torch operator" ,
180
+ {" custom-torch-ops" });
192
181
193
182
args::ValueFlagList<std::string> custom_converters (
194
- parser, " custom-converters" , " Shared object/DLL containing custom converters" , {" custom-converters" });
183
+ parser,
184
+ " custom-converters" ,
185
+ " (repeatable) Shared object/DLL containing custom converters" ,
186
+ {" custom-converters" });
195
187
196
188
args::Positional<std::string> input_path (parser, " input_file_path" , " Path to input TorchScript file" );
197
189
args::Positional<std::string> output_path (
@@ -223,17 +215,28 @@ int main(int argc, char** argv) {
223
215
std::vector<std::pair<std::string, void *>> custom_torch_op, custom_converter_op;
224
216
if (custom_torch_ops) {
225
217
for (auto & op : args::get (custom_torch_ops)) {
226
- void * handle{nullptr };
227
- load_library (op, " custom_torch_ops" , handle);
228
- custom_torch_op.push_back ({op, handle});
218
+ auto * handle = load_library (op);
219
+ if (handle == nullptr ) {
220
+ torchtrt::logging::log (
221
+ torchtrt::logging::Level::kERROR , std::string (" Could not load custom_torch_ops library " + op));
222
+ } else {
223
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Loaded custom_torch_ops library " + op));
224
+
225
+ custom_torch_op.push_back ({op, handle});
226
+ }
229
227
}
230
228
}
231
229
232
230
if (custom_converters) {
233
231
for (auto & op : args::get (custom_converters)) {
234
- void * handle{nullptr };
235
- load_library (op, " custom_converters" , handle);
236
- custom_converter_op.push_back ({op, handle});
232
+ auto * handle = load_library (op);
233
+ if (handle == nullptr ) {
234
+ torchtrt::logging::log (
235
+ torchtrt::logging::Level::kERROR , std::string (" Could not load custom_converter library " + op));
236
+ } else {
237
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Loaded custom_converter library " + op));
238
+ custom_converter_op.push_back ({op, handle});
239
+ }
237
240
}
238
241
}
239
242
@@ -252,7 +255,7 @@ int main(int argc, char** argv) {
252
255
auto method = args::get (check_method_op_support);
253
256
auto result = torchtrt::ts::check_method_operator_support (mod, method);
254
257
if (result) {
255
- std::cout << " The method is supported end to end by Torch-TensorRT" << std::endl ;
258
+ torchtrt::logging::log (torchtrt::logging::Level:: kINFO , " The method is supported end to end by Torch-TensorRT" ) ;
256
259
return 0 ;
257
260
} else {
258
261
torchtrt::logging::log (torchtrt::logging::Level::kERROR , " Method is not currently supported by Torch-TensorRT" );
@@ -542,13 +545,25 @@ int main(int argc, char** argv) {
542
545
543
546
if (custom_torch_ops) {
544
547
for (auto & p : custom_torch_op) {
545
- unload_library (p.second , p.first );
548
+ auto status = unload_library (p.second );
549
+ if (status) {
550
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + p.first ));
551
+ } else {
552
+ torchtrt::logging::log (
553
+ torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + p.first ));
554
+ }
546
555
}
547
556
}
548
557
549
558
if (custom_converters) {
550
559
for (auto & p : custom_converter_op) {
551
- unload_library (p.second , p.first );
560
+ auto status = unload_library (p.second );
561
+ if (status) {
562
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + p.first ));
563
+ } else {
564
+ torchtrt::logging::log (
565
+ torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + p.first ));
566
+ }
552
567
}
553
568
}
554
569
0 commit comments