Skip to content

Commit 79c8259

Browse files
committed
feat(//cpp): Fixed the failure for custom library loading
Signed-off-by: Anurag Dixit <[email protected]>
1 parent baa5659 commit 79c8259

File tree

1 file changed

+51
-36
lines changed

1 file changed

+51
-36
lines changed

cpp/bin/torchtrtc/main.cpp

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,25 @@
2121
#include <dlfcn.h>
2222
#endif
2323

24-
void load_library(std::string& plugin, std::string option, void* handle) {
24+
void* load_library(std::string& custom_lib) {
25+
void* handle = {nullptr};
2526
#if defined(_WIN32)
26-
handle = LoadLibrary(plugin.c_str());
27+
handle = LoadLibrary(custom_lib.c_str());
2728
#else
28-
handle = dlopen(plugin.c_str(), RTLD_LAZY);
29+
handle = dlopen(custom_lib.c_str(), RTLD_LAZY);
2930
#endif
30-
if (handle == nullptr) {
31-
torchtrt::logging::log(
32-
torchtrt::logging::Level::kERROR, std::string("Could not load custom library " + plugin + " for " + option));
33-
} else {
34-
torchtrt::logging::log(
35-
torchtrt::logging::Level::kINFO, std::string("Loaded custom library " + plugin + " for " + option));
36-
}
31+
return handle;
3732
}
3833

39-
void unload_library(void* custom_lib, std::string& name) {
34+
bool unload_library(void* custom_lib) {
35+
bool success = false;
4036
#if defined(_WIN32)
41-
auto status = FreeLibrary(custom_lib);
42-
// Return status non-zero for success
43-
if (status) {
44-
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
45-
} else {
46-
torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
47-
}
37+
// Returns status non-zero for success
38+
success = FreeLibrary(custom_lib) ? true : false;
4839
#else
49-
auto status = dlclose(custom_lib);
50-
// Return status 0 for success
51-
if (!status) {
52-
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
53-
} else {
54-
torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
55-
}
40+
success = dlclose(custom_lib) ? false : true;
5641
#endif
42+
return success;
5743
}
5844

5945
int main(int argc, char** argv) {
@@ -188,10 +174,16 @@ int main(int argc, char** argv) {
188174
"Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
189175
{"save-engine"});
190176
args::ValueFlagList<std::string> custom_torch_ops(
191-
parser, "custom-torch-ops", "Shared object/DLL containing custom torch operator", {"custom-torch-ops"});
177+
parser,
178+
"custom-torch-ops",
179+
"(repeatable) Shared object/DLL containing custom torch operator",
180+
{"custom-torch-ops"});
192181

193182
args::ValueFlagList<std::string> custom_converters(
194-
parser, "custom-converters", "Shared object/DLL containing custom converters", {"custom-converters"});
183+
parser,
184+
"custom-converters",
185+
"(repeatable) Shared object/DLL containing custom converters",
186+
{"custom-converters"});
195187

196188
args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
197189
args::Positional<std::string> output_path(
@@ -223,17 +215,28 @@ int main(int argc, char** argv) {
223215
std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
224216
if (custom_torch_ops) {
225217
for (auto& op : args::get(custom_torch_ops)) {
226-
void* handle{nullptr};
227-
load_library(op, "custom_torch_ops", handle);
228-
custom_torch_op.push_back({op, handle});
218+
auto* handle = load_library(op);
219+
if (handle == nullptr) {
220+
torchtrt::logging::log(
221+
torchtrt::logging::Level::kERROR, std::string("Could not load custom_torch_ops library " + op));
222+
} else {
223+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_torch_ops library " + op));
224+
225+
custom_torch_op.push_back({op, handle});
226+
}
229227
}
230228
}
231229

232230
if (custom_converters) {
233231
for (auto& op : args::get(custom_converters)) {
234-
void* handle{nullptr};
235-
load_library(op, "custom_converters", handle);
236-
custom_converter_op.push_back({op, handle});
232+
auto* handle = load_library(op);
233+
if (handle == nullptr) {
234+
torchtrt::logging::log(
235+
torchtrt::logging::Level::kERROR, std::string("Could not load custom_converter library " + op));
236+
} else {
237+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_converter library " + op));
238+
custom_converter_op.push_back({op, handle});
239+
}
237240
}
238241
}
239242

@@ -252,7 +255,7 @@ int main(int argc, char** argv) {
252255
auto method = args::get(check_method_op_support);
253256
auto result = torchtrt::ts::check_method_operator_support(mod, method);
254257
if (result) {
255-
std::cout << "The method is supported end to end by Torch-TensorRT" << std::endl;
258+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, "The method is supported end to end by Torch-TensorRT");
256259
return 0;
257260
} else {
258261
torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Method is not currently supported by Torch-TensorRT");
@@ -542,13 +545,25 @@ int main(int argc, char** argv) {
542545

543546
if (custom_torch_ops) {
544547
for (auto& p : custom_torch_op) {
545-
unload_library(p.second, p.first);
548+
auto status = unload_library(p.second);
549+
if (status) {
550+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
551+
} else {
552+
torchtrt::logging::log(
553+
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
554+
}
546555
}
547556
}
548557

549558
if (custom_converters) {
550559
for (auto& p : custom_converter_op) {
551-
unload_library(p.second, p.first);
560+
auto status = unload_library(p.second);
561+
if (status) {
562+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
563+
} else {
564+
torchtrt::logging::log(
565+
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
566+
}
552567
}
553568
}
554569

0 commit comments

Comments
 (0)