|
15 | 15 | #include "luts.h"
|
16 | 16 | #include "parser_util.h"
|
17 | 17 |
|
| 18 | +#if defined(_WIN32) |
| 19 | +#include <windows.h> |
| 20 | +#else |
| 21 | +#include <dlfcn.h> |
| 22 | +#endif |
| 23 | + |
| 24 | +void load_library(std::string& plugin, std::string option, void* handle) { |
| 25 | +#if defined(_WIN32) |
| 26 | + handle = LoadLibrary(plugin.c_str()); |
| 27 | +#else |
| 28 | + handle = dlopen(plugin.c_str(), RTLD_LAZY); |
| 29 | +#endif |
| 30 | + if (handle == nullptr) { |
| 31 | + torchtrt::logging::log( |
| 32 | + torchtrt::logging::Level::kERROR, std::string("Could not load custom library " + plugin + " for " + option)); |
| 33 | + } else { |
| 34 | + torchtrt::logging::log( |
| 35 | + torchtrt::logging::Level::kINFO, std::string("Loaded custom library " + plugin + " for " + option)); |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +void unload_library(void* custom_lib, std::string& name) { |
| 40 | +#if defined(_WIN32) |
| 41 | + auto status = FreeLibrary(custom_lib); |
| 42 | + // Return status non-zero for success |
| 43 | + if (status) { |
| 44 | + torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name)); |
| 45 | + } else { |
| 46 | + torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name)); |
| 47 | + } |
| 48 | +#else |
| 49 | + auto status = dlclose(custom_lib); |
| 50 | + // Return status 0 for success |
| 51 | + if (!status) { |
| 52 | + torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + name)); |
| 53 | + } else { |
| 54 | + torchtrt::logging::log(torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + name)); |
| 55 | + } |
| 56 | +#endif |
| 57 | +} |
| 58 | + |
18 | 59 | int main(int argc, char** argv) {
|
19 | 60 | torchtrt::logging::set_is_colored_output_on(true);
|
20 | 61 | torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
|
@@ -117,8 +158,7 @@ int main(int argc, char** argv) {
|
117 | 158 | parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
|
118 | 159 | args::ValueFlag<uint64_t> workspace_size(
|
119 | 160 | parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
|
120 |
| - args::ValueFlag<uint64_t> dla_sram_size( |
121 |
| - parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"}); |
| 161 | + args::ValueFlag<uint64_t> dla_sram_size(parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"}); |
122 | 162 | args::ValueFlag<uint64_t> dla_local_dram_size(
|
123 | 163 | parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"});
|
124 | 164 | args::ValueFlag<uint64_t> dla_global_dram_size(
|
@@ -147,6 +187,12 @@ int main(int argc, char** argv) {
|
147 | 187 | "save_engine",
|
148 | 188 | "Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
|
149 | 189 | {"save-engine"});
|
| 190 | + args::ValueFlagList<std::string> custom_torch_ops( |
| 191 | + parser, "custom-torch-ops", "Shared object/DLL containing custom torch operator", {"custom-torch-ops"}); |
| 192 | + |
| 193 | + args::ValueFlagList<std::string> custom_converters( |
| 194 | + parser, "custom-converters", "Shared object/DLL containing custom converters", {"custom-converters"}); |
| 195 | + |
150 | 196 | args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
|
151 | 197 | args::Positional<std::string> output_path(
|
152 | 198 | parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
|
@@ -174,6 +220,23 @@ int main(int argc, char** argv) {
|
174 | 220 | torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR);
|
175 | 221 | }
|
176 | 222 |
|
| 223 | + std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op; |
| 224 | + if (custom_torch_ops) { |
| 225 | + for (auto& op : args::get(custom_torch_ops)) { |
| 226 | + void* handle{nullptr}; |
| 227 | + load_library(op, "custom_torch_ops", handle); |
| 228 | + custom_torch_op.push_back({op, handle}); |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + if (custom_converters) { |
| 233 | + for (auto& op : args::get(custom_converters)) { |
| 234 | + void* handle{nullptr}; |
| 235 | + load_library(op, "custom_converters", handle); |
| 236 | + custom_converter_op.push_back({op, handle}); |
| 237 | + } |
| 238 | + } |
| 239 | + |
177 | 240 | auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path));
|
178 | 241 |
|
179 | 242 | if (check_method_op_support) {
|
@@ -477,5 +540,17 @@ int main(int argc, char** argv) {
|
477 | 540 | trt_mod.save(real_output_path);
|
478 | 541 | }
|
479 | 542 |
|
| 543 | + if (custom_torch_ops) { |
| 544 | + for (auto& p : custom_torch_op) { |
| 545 | + unload_library(p.second, p.first); |
| 546 | + } |
| 547 | + } |
| 548 | + |
| 549 | + if (custom_converters) { |
| 550 | + for (auto& p : custom_converter_op) { |
| 551 | + unload_library(p.second, p.first); |
| 552 | + } |
| 553 | + } |
| 554 | + |
480 | 555 | return 0;
|
481 | 556 | }
|
0 commit comments