Skip to content

Commit d6ea0d5

Browse files
committed
feat: Jetson specific workspace file
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8929623 commit d6ea0d5

File tree

3 files changed

+118
-4
lines changed

3 files changed

+118
-4
lines changed

py/setup.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,18 @@ def get_git_revision_short_hash() -> str:
7272
elif version == "4.6":
7373
JETPACK_VERSION = "4.6"
7474
elif version == "5.0":
75-
JETPACK_VERSION = "4.6"
75+
JETPACK_VERSION = "5.0"
76+
7677
if not JETPACK_VERSION:
7778
warnings.warn(
78-
"Assuming jetpack version to be 4.6 or greater, if not use the --jetpack-version option"
79+
"Assuming jetpack version to be 5.0, if not use the --jetpack-version option"
80+
)
81+
JETPACK_VERSION = "5.0"
82+
83+
if not CXX11_ABI:
84+
warnings.warn(
85+
"Jetson platform detected but did not see --use-cxx11-abi option, if using a pytorch distribution provided by NVIDIA include this flag"
7986
)
80-
JETPACK_VERSION = "4.6"
8187

8288

8389
def which(program):
@@ -128,7 +134,10 @@ def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=F
128134
print("Jetpack version: 4.5")
129135
elif JETPACK_VERSION == "4.6":
130136
cmd.append("--platforms=//toolchains:jetpack_4.6")
131-
print("Jetpack version: >=4.6")
137+
print("Jetpack version: 4.6")
138+
elif JETPACK_VERSION == "5.0":
139+
cmd.append("--platforms=//toolchains:jetpack_5.0")
140+
print("Jetpack version: 5.0")
132141

133142
if CI_RELEASE:
134143
cmd.append("--platforms=//toolchains:ci_rhel_x86_64_linux")

toolchains/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ platform(
2626
],
2727
)
2828

29+
platform(
30+
name = "jetpack_5.0",
31+
constraint_values = [
32+
"@platforms//os:linux",
33+
"@platforms//cpu:aarch64",
34+
"@//toolchains/jetpack:4.6",
35+
],
36+
)
37+
2938
platform(
3039
name = "ci_rhel_x86_64_linux",
3140
constraint_values = [
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
workspace(name = "Torch-TensorRT")
2+
3+
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
4+
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
5+
6+
http_archive(
7+
name = "rules_python",
8+
sha256 = "778197e26c5fbeb07ac2a2c5ae405b30f6cb7ad1f5510ea6fdac03bded96cc6f",
9+
url = "https://github.com/bazelbuild/rules_python/releases/download/0.2.0/rules_python-0.2.0.tar.gz",
10+
)
11+
12+
load("@rules_python//python:pip.bzl", "pip_install")
13+
14+
http_archive(
15+
name = "rules_pkg",
16+
sha256 = "038f1caa773a7e35b3663865ffb003169c6a71dc995e39bf4815792f385d837d",
17+
urls = [
18+
"https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
19+
"https://github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
20+
],
21+
)
22+
23+
load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
24+
25+
rules_pkg_dependencies()
26+
27+
git_repository(
28+
name = "googletest",
29+
commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
30+
remote = "https://github.com/google/googletest",
31+
shallow_since = "1570114335 -0400",
32+
)
33+
34+
# External dependency for torch_tensorrt if you already have precompiled binaries.
35+
local_repository(
36+
name = "torch_tensorrt",
37+
path = "/opt/conda/lib/python3.8/site-packages/torch_tensorrt",
38+
)
39+
40+
# CUDA should be installed on the system locally
41+
new_local_repository(
42+
name = "cuda",
43+
build_file = "@//third_party/cuda:BUILD",
44+
path = "/usr/local/cuda-11.4/",
45+
)
46+
47+
new_local_repository(
48+
name = "cublas",
49+
build_file = "@//third_party/cublas:BUILD",
50+
path = "/usr",
51+
)
52+
53+
####################################################################################
54+
# Locally installed dependencies (use in cases of custom dependencies or aarch64)
55+
####################################################################################
56+
57+
# NOTE: In the case you are using just the pre-cxx11-abi path or just the cxx11 abi path
58+
# with your local libtorch, just point deps at the same path to satisfy bazel.
59+
60+
# NOTE: NVIDIA's aarch64 PyTorch (python) wheel file uses the CXX11 ABI unlike PyTorch's standard
61+
# x86_64 python distribution. If using NVIDIA's version just point to the root of the package
62+
# for both versions here and do not use --config=pre-cxx11-abi
63+
64+
new_local_repository(
65+
name = "libtorch",
66+
path = "/usr/local/lib/python3.8/dist-packages/torch",
67+
build_file = "third_party/libtorch/BUILD"
68+
)
69+
70+
# NOTE: Unused on aarch64-jetson with NVIDIA provided PyTorch distribu†ion
71+
new_local_repository(
72+
name = "libtorch_pre_cxx11_abi",
73+
path = "/usr/local/lib/python3.8/dist-packages/torch",
74+
build_file = "third_party/libtorch/BUILD"
75+
)
76+
77+
new_local_repository(
78+
name = "cudnn",
79+
path = "/usr/",
80+
build_file = "@//third_party/cudnn/local:BUILD"
81+
)
82+
83+
new_local_repository(
84+
name = "tensorrt",
85+
path = "/usr/",
86+
build_file = "@//third_party/tensorrt/local:BUILD"
87+
)
88+
89+
#########################################################################
90+
# Development Dependencies (optional - comment out on aarch64)
91+
#########################################################################
92+
93+
pip_install(
94+
name = "devtools_deps",
95+
requirements = "//:requirements-dev.txt",
96+
)

0 commit comments

Comments
 (0)