Commit de3ba23

merge: resolve the conflict in AddEngineToGraph argument

Signed-off-by: Bo Wang <[email protected]>

2 parents: c1934c1 + 72cb449

45 files changed: +997 −145 lines

.github/scripts/run_cpp_linter.py

Lines changed: 4 additions & 1 deletion
@@ -26,4 +26,7 @@
 
 pr.create_review(commit, comment, approval)
 
-
+if output.returncode != 0:
+    exit(1)
+else:
+    exit(0)

.github/scripts/run_py_linter.py

Lines changed: 5 additions & 0 deletions
@@ -25,3 +25,8 @@
 approval = 'REQUEST_CHANGES'
 
 pr.create_review(commit, comment, approval)
+
+if output.returncode != 0:
+    exit(1)
+else:
+    exit(0)

README.md

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
 These are the following dependencies used to verify the testcases. TRTorch can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 4.0.0
-- Libtorch 1.8.0 (built with CUDA 11.1)
+- Libtorch 1.8.1 (built with CUDA 11.1)
 - CUDA 11.1 (10.2 on Jetson)
 - cuDNN 8.1
 - TensorRT 7.2.3

WORKSPACE

Lines changed: 28 additions & 41 deletions
@@ -3,23 +3,21 @@ workspace(name = "TRTorch")
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
 
-git_repository(
-    name = "rules_python",
-    remote = "https://github.com/bazelbuild/rules_python.git",
-    commit = "4fcc24fd8a850bdab2ef2e078b1de337eea751a6",
-    shallow_since = "1589292086 -0400"
-)
-
-load("@rules_python//python:repositories.bzl", "py_repositories")
-py_repositories()
+http_archive(
+    name = "rules_python",
+    url = "https://github.com/bazelbuild/rules_python/releases/download/0.2.0/rules_python-0.2.0.tar.gz",
+    sha256 = "778197e26c5fbeb07ac2a2c5ae405b30f6cb7ad1f5510ea6fdac03bded96cc6f",
+)
 
-load("@rules_python//python:pip.bzl", "pip_repositories", "pip3_import")
-pip_repositories()
+load("@rules_python//python:pip.bzl", "pip_install")
 
 http_archive(
     name = "rules_pkg",
-    url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.4/rules_pkg-0.2.4.tar.gz",
-    sha256 = "4ba8f4ab0ff85f2484287ab06c0d871dcb31cc54d439457d28fd4ae14b18450a",
+    urls = [
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
+        "https://github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
+    ],
+    sha256 = "038f1caa773a7e35b3663865ffb003169c6a71dc995e39bf4815792f385d837d",
 )
 
 load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
@@ -39,12 +37,6 @@ new_local_repository(
     build_file = "@//third_party/cuda:BUILD",
 )
 
-new_local_repository(
-    name = "cublas",
-    path = "/usr",
-    build_file = "@//third_party/cublas:BUILD",
-)
-
 #############################################################################################################
 # Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs)
 #############################################################################################################
@@ -53,16 +45,16 @@ http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
     strip_prefix = "libtorch",
-    sha256 = "62a2c06761c32576b30f5884240cf675b937945d929e4b13cc776de8d9c2236c",
-    urls = ["https://download.pytorch.org/libtorch/cu111/libtorch-cxx11-abi-shared-with-deps-1.8.0%2Bcu111.zip"],
+    sha256 = "1f8aec376f9343538bd7c2fd3abb81ed3af11f575efe3aa72777c4d62044b832",
+    urls = ["https://download.pytorch.org/libtorch/cu111/libtorch-cxx11-abi-shared-with-deps-1.8.1%2Bcu111.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
     strip_prefix = "libtorch",
-    sha256 = "1c8b0c0883dd17f5ce952d42ec5f7f0cc7ceb370307535cee26a66c10419f1f6",
-    urls = ["https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.0%2Bcu111.zip"],
+    sha256 = "3a6e0dc11859111e75caa640c8ce9bf904fbb6e9992b4345e444ed5410e4d77e",
+    urls = ["https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip"],
 )
 
 # Download these tarballs manually from the NVIDIA website
@@ -71,15 +63,19 @@ http_archive(
 
 http_archive(
     name = "cudnn",
-    urls = ["https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.1.33/11.2_20210301/cudnn-11.2-linux-x64-v8.1.1.33.tgz",],
+    urls = [
+        "https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.1.33/11.2_20210301/cudnn-11.2-linux-x64-v8.1.1.33.tgz",
+    ],
     build_file = "@//third_party/cudnn/archive:BUILD",
     sha256 = "98a8784e92862f20018d20c281b30d4a0cd951f93694f6433ccf4ae9c502ba6a",
     strip_prefix = "cuda"
 )
 
 http_archive(
     name = "tensorrt",
-    urls = ["https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.3/tars/TensorRT-7.2.3.4.Ubuntu-18.04.x86_64-gnu.cuda-11.1.cudnn8.1.tar.gz",],
+    urls = [
+        "https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.3/tars/TensorRT-7.2.3.4.Ubuntu-18.04.x86_64-gnu.cuda-11.1.cudnn8.1.tar.gz",
+    ],
     build_file = "@//third_party/tensorrt/archive:BUILD",
     strip_prefix = "TensorRT-7.2.3.4",
     sha256 = "d3a1f478e304b48878604fac70ce7920fece71f9cac62f925c9c59c197f5d087"
@@ -123,26 +119,17 @@ http_archive(
 #########################################################################
 # Testing Dependencies (optional - comment out on aarch64)
 #########################################################################
-pip3_import(
+pip_install(
     name = "trtorch_py_deps",
-    requirements = "//py:requirements.txt"
+    requirements = "//py:requirements.txt",
 )
 
-load("@trtorch_py_deps//:requirements.bzl", "pip_install")
-pip_install()
-
-pip3_import(
+pip_install(
     name = "py_test_deps",
-    requirements = "//tests/py:requirements.txt"
+    requirements = "//tests/py:requirements.txt",
 )
 
-load("@py_test_deps//:requirements.bzl", "pip_install")
-pip_install()
-
-pip3_import(
-    name = "pylinter_deps",
-    requirements = "//tools/linter:requirements.txt",
+pip_install(
+    name = "pylinter_deps",
+    requirements = "//tools/linter:requirements.txt",
 )
-
-load("@pylinter_deps//:requirements.bzl", "pip_install")
-pip_install()

core/compiler.cpp

Lines changed: 15 additions & 1 deletion
@@ -30,7 +30,7 @@ namespace core {
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine,
+    const std::string& serialized_engine,
     int engine_id = 0) {
   auto engine_ptr =
       c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(engine_id), serialized_engine);
@@ -267,6 +267,20 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+  std::ostringstream engine_id;
+  engine_id << reinterpret_cast<const int*>(&engine);
+  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
+  auto new_g = std::make_shared<torch::jit::Graph>();
+  AddEngineToGraph(new_mod, new_g, engine);
+  auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
+  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  new_mod.type()->addMethod(new_method);
+  new_method->setSchema(schema);
+
+  return new_mod;
+}
+
 void set_device(const int gpu_id) {
   TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
 }
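
Note on the new `EmbedEngineInNewModule`: the module name suffix comes from streaming the address of the `engine` string; the `reinterpret_cast<const int*>` exists only so `operator<<` prints the pointer value, which is never dereferenced. A minimal standalone sketch of that naming trick (not TRTorch code):

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string engine = "<serialized TensorRT engine bytes>";
  std::ostringstream engine_id;
  // Streaming a pointer prints its address (e.g. "0x7ffee4c0a2d0"),
  // giving each distinct string object a distinct module-name suffix.
  engine_id << reinterpret_cast<const int*>(&engine);
  std::cout << "tensorrt_engine_mod_" + engine_id.str() << "\n";
}
```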

core/compiler.h

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
 
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+
 void set_device(const int gpu_id);
 
 } // namespace core

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
   os << "Settings requested for TensorRT engine:" \
      << "\n Operating Precision: " << s.op_precision \
      << "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
+     << "\n Truncate Long and Double: " << s.truncate_long_and_double \
      << "\n Make Refittable Engine: " << s.refit \
      << "\n Debuggable Engine: " << s.debug \
      << "\n Strict Types: " << s.strict_types \

core/conversion/evaluators/aten.cpp

Lines changed: 13 additions & 3 deletions
@@ -468,11 +468,21 @@ auto aten_registrations TRTORCH_UNUSED =
                     })})
         .evaluator({c10::Symbol::fromQualString("aten::floor"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      auto el = args.at(n->input(0)).unwrapToDouble();
-
-                      return static_cast<int64_t>(std::floor(el));
+                      if (args.at(n->input(0)).IValue()->isInt()) {
+                        auto el = args.at(n->input(0)).unwrapToInt();
+                        return static_cast<int64_t>(std::floor(el));
+                      } else if (args.at(n->input(0)).IValue()->isDouble()) {
+                        auto el = args.at(n->input(0)).unwrapToDouble();
+                        return static_cast<int64_t>(std::floor(el));
+                      } else {
+                        TRTORCH_THROW_ERROR(
+                            "Unimplemented data type for aten::floor evaluator: "
+                            << args.at(n->input(0)).IValue()->type()->str());
+                        return {};
+                      }
                     },
                     EvalOptions().validSchemas({
+                        "aten::floor.int(int a) -> (int)",
                         "aten::floor.float(float a) -> (int)",
                     })})
         .evaluator({c10::Symbol::fromQualString("aten::warn"),
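
For context, `args.at(n->input(0)).IValue()` hands back a `c10::IValue`, whose runtime type must be checked before unwrapping; the old code assumed a double and broke on `aten::floor.int`. A hypothetical standalone helper (`floor_ivalue` is not TRTorch code) showing the same dispatch pattern:

```cpp
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <ATen/core/ivalue.h>

// Dispatch on the IValue's runtime type instead of assuming it holds a double.
int64_t floor_ivalue(const c10::IValue& v) {
  if (v.isInt()) {
    return v.toInt();  // floor of an integer is the integer itself
  } else if (v.isDouble()) {
    return static_cast<int64_t>(std::floor(v.toDouble()));
  }
  throw std::runtime_error("Unimplemented data type for floor: " + v.type()->str());
}
```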

cpp/api/include/trtorch/trtorch.h

Lines changed: 15 additions & 0 deletions
@@ -511,6 +511,21 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
     const torch::jit::Module& module,
     std::string method_name,
     CompileSpec info);
+
+/**
+ * @brief Take a previously created TensorRT engine and embed it
+ * in a TorchScript module
+ *
+ * @param engine: std::string - Pre-built serialized TensorRT engine
+ *
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module.
+ * Forward is defined as: forward(Tensor[]) -> Tensor[]
+ *
+ * @return: A new module targeting a TensorRT engine
+ */
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
+
 /**
  * @brief Set gpu device id
  *
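
A hedged usage sketch of the new API (the `engine.trt` path and the 1x3x224x224 input shape are illustrative assumptions; the engine could come from an earlier `ConvertGraphToTRTEngine` call):

```cpp
#include <fstream>
#include <sstream>
#include <string>

#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Read a pre-built serialized TensorRT engine from disk
  // ("engine.trt" is a placeholder path).
  std::ifstream engine_file("engine.trt", std::ios::binary);
  std::stringstream buf;
  buf << engine_file.rdbuf();
  std::string engine = buf.str();

  // Wrap the engine in a fresh TorchScript module whose forward()
  // executes it: forward(Tensor[]) -> Tensor[].
  torch::jit::Module mod = trtorch::EmbedEngineInNewModule(engine);

  // Inputs go in as a list of tensors; shapes must match what the
  // engine was built for (assumed here: 1x3x224x224 on the GPU).
  c10::List<at::Tensor> inputs;
  inputs.push_back(torch::randn({1, 3, 224, 224}, torch::kCUDA));
  auto outputs = mod.forward({inputs}).toTensorList();
}
```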

cpp/api/src/trtorch.cpp

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }
 
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}
+
 std::string get_build_info() {
   auto info = core::util::get_build_info();
   return std::string("TRTorch Version: ") + TRTORCH_VERSION + '\n' + info;
