Skip to content

Commit baa5659

Browse files
committed
feat(cpp): Added support for loading runtime custom torch op and custom converters in torchtrtc
Signed-off-by: Anurag Dixit <[email protected]>
1 parent db61e90 commit baa5659

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

cpp/bin/torchtrtc/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ cc_binary(
1919
"parser_util.h",
2020
"parser_util.cpp"
2121
],
22+
linkopts = [
23+
"-l:libdl.so"
24+
],
2225
deps = [
2326
"//third_party/args",
2427
"//cpp:torch_tensorrt",

cpp/bin/torchtrtc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ add_executable(${executable_name}
1010
if (MSVC)
1111
target_link_libraries(${executable_name} PRIVATE torch torchtrt)
1212
else()
13-
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed" torchtrt "-Wl,--as-needed")
13+
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed -ldl" torchtrt "-Wl,--as-needed")
1414
set_target_properties(
1515
${executable_name}
1616
PROPERTIES INSTALL_RPATH_USE_LINK_PATH FALSE #

cpp/bin/torchtrtc/main.cpp

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,47 @@
1515
#include "luts.h"
1616
#include "parser_util.h"
1717

18+
#if defined(_WIN32)
19+
#include <windows.h>
20+
#else
21+
#include <dlfcn.h>
22+
#endif
23+
24+
void load_library(std::string& plugin, std::string option, void* handle) {
25+
#if defined(_WIN32)
26+
handle = LoadLibrary(plugin.c_str());
27+
#else
28+
handle = dlopen(plugin.c_str(), RTLD_LAZY);
29+
#endif
30+
if (handle == nullptr) {
31+
torchtrt::logging::log(
32+
torchtrt::logging::Level::kERROR, std::string("Could not load custom library " + plugin + " for " + option));
33+
} else {
34+
torchtrt::logging::log(
35+
torchtrt::logging::Level::kINFO, std::string("Loaded custom library " + plugin + " for " + option));
36+
}
37+
}
38+
39+
void unload_library(void* custom_lib, std::string& name) {
40+
#if defined(_WIN32)
41+
auto status = FreeLibrary(custom_lib);
42+
// Return status non-zero for success
43+
if (status) {
44+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
45+
} else {
46+
torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
47+
}
48+
#else
49+
auto status = dlclose(custom_lib);
50+
// Return status 0 for success
51+
if (!status) {
52+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name));
53+
} else {
54+
torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name));
55+
}
56+
#endif
57+
}
58+
1859
int main(int argc, char** argv) {
1960
torchtrt::logging::set_is_colored_output_on(true);
2061
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
@@ -117,8 +158,7 @@ int main(int argc, char** argv) {
117158
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
118159
args::ValueFlag<uint64_t> workspace_size(
119160
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
120-
args::ValueFlag<uint64_t> dla_sram_size(
121-
parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
161+
args::ValueFlag<uint64_t> dla_sram_size(parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
122162
args::ValueFlag<uint64_t> dla_local_dram_size(
123163
parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"});
124164
args::ValueFlag<uint64_t> dla_global_dram_size(
@@ -147,6 +187,12 @@ int main(int argc, char** argv) {
147187
"save_engine",
148188
"Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
149189
{"save-engine"});
190+
args::ValueFlagList<std::string> custom_torch_ops(
191+
parser, "custom-torch-ops", "Shared object/DLL containing custom torch operator", {"custom-torch-ops"});
192+
193+
args::ValueFlagList<std::string> custom_converters(
194+
parser, "custom-converters", "Shared object/DLL containing custom converters", {"custom-converters"});
195+
150196
args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
151197
args::Positional<std::string> output_path(
152198
parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
@@ -174,6 +220,23 @@ int main(int argc, char** argv) {
174220
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR);
175221
}
176222

223+
std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
224+
if (custom_torch_ops) {
225+
for (auto& op : args::get(custom_torch_ops)) {
226+
void* handle{nullptr};
227+
load_library(op, "custom_torch_ops", handle);
228+
custom_torch_op.push_back({op, handle});
229+
}
230+
}
231+
232+
if (custom_converters) {
233+
for (auto& op : args::get(custom_converters)) {
234+
void* handle{nullptr};
235+
load_library(op, "custom_converters", handle);
236+
custom_converter_op.push_back({op, handle});
237+
}
238+
}
239+
177240
auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path));
178241

179242
if (check_method_op_support) {
@@ -477,5 +540,17 @@ int main(int argc, char** argv) {
477540
trt_mod.save(real_output_path);
478541
}
479542

543+
if (custom_torch_ops) {
544+
for (auto& p : custom_torch_op) {
545+
unload_library(p.second, p.first);
546+
}
547+
}
548+
549+
if (custom_converters) {
550+
for (auto& p : custom_converter_op) {
551+
unload_library(p.second, p.first);
552+
}
553+
}
554+
480555
return 0;
481556
}

0 commit comments

Comments
 (0)