Skip to content

Commit 996e33b

Browse files
Add support for TensorRT-RTX (#3753)
1 parent 079400c commit 996e33b

File tree

111 files changed

+3703
-1185
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+3703
-1185
lines changed

.bazelrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ build:cxx11_abi --define=abi=cxx11_abi
3939

4040
build:jetpack --//toolchains/dep_collection:compute_libs=jetpack
4141

42+
build:rtx --//toolchains/dep_collection:compute_libs=rtx
43+
4244
build:ci_testing --define=torchtrt_src=prebuilt --cxxopt="-DDISABLE_TEST_IN_CI" --action_env "NVIDIA_TF32_OVERRIDE=0"
4345
build:use_precompiled_torchtrt --define=torchtrt_src=prebuilt
4446

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
2+
install_tensorrt_rtx() {
3+
if [[ ${USE_TRT_RTX} == true ]]; then
4+
TRT_RTX_VERSION=1.0.0.21
5+
install_wheel_or_not=${1:-false}
6+
echo "It is the tensorrt-rtx build, install tensorrt-rtx with install_wheel_or_not:${install_wheel_or_not}"
7+
PLATFORM=$(python -c "import sys; print(sys.platform)")
8+
echo "PLATFORM: $PLATFORM"
9+
# PYTHON_VERSION is always set in the CI environment, add this check for local testing
10+
if [ -z "$PYTHON_VERSION" ]; then
11+
echo "Error: PYTHON_VERSION environment variable is not set or empty. example format: export PYTHON_VERSION=3.11"
12+
exit 1
13+
fi
14+
15+
# python version is like 3.11, we need to convert it to cp311
16+
CPYTHON_TAG="cp${PYTHON_VERSION//./}"
17+
if [[ ${PLATFORM} == win32 ]]; then
18+
curl -L https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.0/TensorRT-RTX-${TRT_RTX_VERSION}.Windows.win10.cuda-12.9.zip -o TensorRT-RTX-${TRT_RTX_VERSION}.Windows.win10.cuda-12.9.zip
19+
unzip TensorRT-RTX-${TRT_RTX_VERSION}.Windows.win10.cuda-12.9.zip
20+
rtx_lib_dir=${PWD}/TensorRT-RTX-${TRT_RTX_VERSION}/lib
21+
export PATH=${rtx_lib_dir}:$PATH
22+
echo "PATH: $PATH"
23+
if [[ ${install_wheel_or_not} == true ]]; then
24+
pip install TensorRT-RTX-${TRT_RTX_VERSION}/python/tensorrt_rtx-${TRT_RTX_VERSION}-${CPYTHON_TAG}-none-win_amd64.whl
25+
fi
26+
else
27+
curl -L https://developer.nvidia.com/downloads/trt/rtx_sdk/secure/1.0/TensorRT-RTX-${TRT_RTX_VERSION}.Linux.x86_64-gnu.cuda-12.9.tar.gz -o TensorRT-RTX-${TRT_RTX_VERSION}.Linux.x86_64-gnu.cuda-12.9.tar.gz
28+
tar -xzf TensorRT-RTX-${TRT_RTX_VERSION}.Linux.x86_64-gnu.cuda-12.9.tar.gz
29+
rtx_lib_dir=${PWD}/TensorRT-RTX-${TRT_RTX_VERSION}/lib
30+
export LD_LIBRARY_PATH=${rtx_lib_dir}:$LD_LIBRARY_PATH
31+
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
32+
if [[ ${install_wheel_or_not} == true ]]; then
33+
pip install TensorRT-RTX-${TRT_RTX_VERSION}/python/tensorrt_rtx-${TRT_RTX_VERSION}-${CPYTHON_TAG}-none-linux_x86_64.whl
34+
fi
35+
fi
36+
else
37+
echo "It is the standard tensorrt build, skip install tensorrt-rtx"
38+
fi
39+
40+
}

.github/scripts/install-torch-tensorrt.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ pip uninstall -y torch torchvision
2121
pip install --force-reinstall --pre ${TORCHVISION} --index-url ${INDEX_URL}
2222
pip install --force-reinstall --pre ${TORCH} --index-url ${INDEX_URL}
2323

24+
if [[ ${USE_TRT_RTX} == true ]]; then
25+
source .github/scripts/install-tensorrt-rtx.sh
26+
# tensorrt-rtx is not publicly available, so we need to install the wheel from the tar ball
27+
install_wheel_or_not=true
28+
install_tensorrt_rtx ${install_wheel_or_not}
29+
fi
2430

2531
# Install Torch-TensorRT
2632
if [[ ${PLATFORM} == win32 ]]; then

.github/workflows/build-test-linux-aarch64-jetpack.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
smoke-test-script: packaging/smoke_test_script.sh
6767
package-name: torch_tensorrt
6868
name: Build torch-tensorrt whl package for jetpack
69-
uses: ./.github/workflows/build_wheels_linux_aarch64.yml
69+
uses: ./.github/workflows/build_wheels_linux.yml
7070
with:
7171
repository: ${{ matrix.repository }}
7272
ref: ""

.github/workflows/build-test-linux-aarch64.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
smoke-test-script: packaging/smoke_test_script.sh
6464
package-name: torch_tensorrt
6565
name: Build torch-tensorrt whl package for SBSA
66-
uses: ./.github/workflows/build_wheels_linux_aarch64.yml
66+
uses: ./.github/workflows/build_wheels_linux.yml
6767
with:
6868
repository: ${{ matrix.repository }}
6969
ref: ""

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ on:
1313
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
1414
workflow_dispatch:
1515

16+
1617
jobs:
1718
generate-matrix:
1819
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
@@ -60,8 +61,8 @@ jobs:
6061
post-script: packaging/post_build_script.sh
6162
smoke-test-script: packaging/smoke_test_script.sh
6263
package-name: torch_tensorrt
63-
name: Build torch-tensorrt whl package
64-
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
64+
name: Build torch-tensorrt whl package for Linux x86_64
65+
uses: ./.github/workflows/build_wheels_linux.yml
6566
with:
6667
repository: ${{ matrix.repository }}
6768
ref: ""
@@ -74,6 +75,8 @@ jobs:
7475
package-name: ${{ matrix.package-name }}
7576
smoke-test-script: ${{ matrix.smoke-test-script }}
7677
trigger-event: ${{ github.event_name }}
78+
architecture: "x86_64"
79+
use-rtx: false
7780

7881
tests-py-torchscript-fe:
7982
name: Test torchscript frontend [Python]
@@ -338,5 +341,5 @@ jobs:
338341
popd
339342
340343
concurrency:
341-
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
344+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-tensorrt-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
342345
cancel-in-progress: true

0 commit comments

Comments
 (0)