
Commit dbc061c

Cherry pick jetson enablement from 2.8 release branch to main (#3765)
1 parent 452ad05 · commit dbc061c

File tree

14 files changed: +357 −349 lines


.github/workflows/build-test-linux-aarch64-jetpack.yml

Lines changed: 10 additions & 11 deletions
@@ -1,17 +1,16 @@
 name: Build and test Linux aarch64 wheels for Jetpack
 
 on:
-  # TODO: Uncomment this when we have a stable release
-  # pull_request:
-  # push:
-  #   branches:
-  #     - main
-  #     - nightly
-  #     - release/*
-  #   tags:
-  #     # NOTE: Binary build pipelines should only get triggered on release candidate builds
-  #     # Release candidate tags look like: v1.11.0-rc1
-  #     - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+  pull_request:
+  push:
+    branches:
+      - main
+      - nightly
+      - release/*
+    tags:
+      # NOTE: Binary build pipelines should only get triggered on release candidate builds
+      # Release candidate tags look like: v1.11.0-rc1
+      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
   workflow_dispatch:
 
 jobs:

.github/workflows/build_wheels_linux_aarch64.yml

Lines changed: 3 additions & 3 deletions
@@ -264,7 +264,7 @@ jobs:
           if [[ ${{ inputs.is-jetpack }} == false ]]; then
             ${CONDA_RUN} python setup.py bdist_wheel
           else
-            ${CONDA_RUN} python setup.py bdist_wheel --jetpack --plat-name=linux_tegra_aarch64
+            ${CONDA_RUN} python setup.py bdist_wheel --jetpack
           fi
       - name: Repair Manylinux_2_28 Wheel
         shell: bash -l {0}
@@ -337,8 +337,8 @@ jobs:
     needs: build
     name: upload-wheel-${{ matrix.python_version }}-${{ matrix.desired_cuda }}-${{ matrix.gpu_arch_type }}-${{ inputs.is-jetpack }}
     uses: pytorch/test-infra/.github/workflows/_binary_upload.yml@main
-    # for jetpack builds, only upload to pytorch index for nightly builds
-    if: ${{ inputs.is-jetpack == false || (github.event_name == 'push' && startsWith(github.event.ref, 'refs/heads/nightly')) }}
+    # for jetpack builds, do not upload to pytorch nightly index, only upload to https://pypi.jetson-ai-lab.io/ manually for each release
+    if: ${{ inputs.is-jetpack == false }}
     with:
       repository: ${{ inputs.repository }}
       ref: ${{ inputs.ref }}

MODULE.bazel

Lines changed: 1 addition & 2 deletions
@@ -90,10 +90,9 @@ http_archive(
 http_archive(
     name = "torch_l4t",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "6eff643c0a7acda92734cc798338f733ff35c7df1a4434576f5ff7c66fc97319",
     strip_prefix = "torch",
     type = "zip",
-    urls = ["https://pypi.jetson-ai-lab.dev/jp6/cu126/+f/6ef/f643c0a7acda9/torch-2.7.0-cp310-cp310-linux_aarch64.whl"],
+    urls = ["https://pypi.jetson-ai-lab.io/jp6/cu126/+f/62a/1beee9f2f1470/torch-2.8.0-cp310-cp310-linux_aarch64.whl"],
 )
 
 # Download these tarballs manually from the NVIDIA website

docsrc/getting_started/jetpack.rst

Lines changed: 2 additions & 2 deletions
@@ -90,14 +90,14 @@ Build Environment Setup
 .. code-block:: sh
 
     # Can only install the torch and torchvision wheel from the JPL repo which is built specifically for JetPack 6.2
-    python -m pip install torch==2.7.0 torchvision==0.22.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
+    python -m pip install torch==2.8.0 torchvision==0.23.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126
 
 
 Building the Wheel
 ==================
 
 .. code-block:: sh
-    python setup.py bdist_wheel
+    python setup.py bdist_wheel --jetpack
 
 Installation
 ============

packaging/pre_build_script.sh

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.26.0/bazelis
 pip uninstall -y torch torchvision
 
 if [[ ${IS_JETPACK} == true ]]; then
-  # install torch 2.7 for jp6.2
-  pip install torch==2.7.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
+  # install torch 2.8 for jp6.2
+  pip install torch==2.8.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126/
 else
   TORCH=$(grep "^torch>" py/requirements.txt)
   INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}

py/torch_tensorrt/_enums.py

Lines changed: 5 additions & 4 deletions
@@ -8,6 +8,7 @@
 import tensorrt as trt
 import torch
 from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
+from torch_tensorrt._utils import is_tensorrt_version_supported
 
 
 class dtype(Enum):
@@ -199,8 +200,6 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
-        elif t == trt.DataType.FP4:
-            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:
@@ -214,6 +213,8 @@ def _from(
         elif t == trt.DataType.BF16:
             return dtype.bf16
         else:
+            if is_tensorrt_version_supported("10.8.0") and t == trt.DataType.FP4:
+                return dtype.fp4
             raise TypeError(
                 f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
             )
@@ -409,11 +410,11 @@ def to(
                 return trt.DataType.BOOL
             elif self == dtype.bf16:
                 return trt.DataType.BF16
-            elif self == dtype.f4:
-                return trt.DataType.FP4
             elif use_default:
                 return trt.DataType.FLOAT
             else:
+                if is_tensorrt_version_supported("10.8.0") and self == dtype.f4:
+                    return trt.DataType.FP4
                 raise TypeError("Unsupported tensorrt dtype")
 
         elif t == np.dtype:

py/torch_tensorrt/_utils.py

Lines changed: 23 additions & 0 deletions
@@ -24,3 +24,26 @@ def check_cross_compile_trt_win_lib() -> bool:
         target_lib = ".*libnvinfer_builder_resource_win.so.*"
         return any(re.match(target_lib, lib) for lib in loaded_libs)
     return False
+
+
+def is_tensorrt_version_supported(min_version: str = "10.8.0") -> bool:
+    """
+    Check if the installed TensorRT version supports the specified minimum version.
+    Args:
+        min_version (str): Minimum required TensorRT version (default: "10.8.0" for FP4 support)
+    Returns:
+        bool: True if TensorRT version is >= min_version, False otherwise
+    Example:
+        >>> if is_tensorrt_version_supported("10.8.0"):
+        ...     # Use FP4 features
+        ...     pass
+    """
+    try:
+        from importlib import metadata
+
+        from packaging.version import Version
+
+        return bool(Version(metadata.version("tensorrt")) >= Version(min_version))
+    except (ImportError, ValueError):
+        # If tensorrt is not installed or version cannot be determined
+        return False
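
For readers following the cherry-pick, here is a minimal usage sketch of the helper introduced above (not part of the diff; the function name fp4_trt_dtype below is hypothetical). Callers such as _enums.py and aten_ops_converters.py gate every FP4-specific path on is_tensorrt_version_supported("10.8.0"), so installs with an older TensorRT fall back to the pre-existing behaviour instead of referencing trt.DataType.FP4:

import tensorrt as trt

from torch_tensorrt._utils import is_tensorrt_version_supported


def fp4_trt_dtype() -> "trt.DataType":
    # Only touch trt.DataType.FP4 when the installed TensorRT exposes it
    # (10.8.0 or newer); otherwise fail the same way the older code path did.
    if is_tensorrt_version_supported("10.8.0"):
        return trt.DataType.FP4
    raise TypeError("Unsupported tensorrt dtype")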

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 31 deletions
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -620,40 +621,41 @@ def aten_ops_quantize_op(
     )
 
 
-try:
-    import modelopt.torch.quantization as mtq  # noqa: F401
+if is_tensorrt_version_supported("10.8.0"):
+    try:
+        import modelopt.torch.quantization as mtq  # noqa: F401
 
-    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
-except Exception as e:
-    _LOGGER.warning(
-        "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
-    )
-else:
+        assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+    except Exception as e:
+        _LOGGER.warning(
+            "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+        )
+    else:
 
-    @dynamo_tensorrt_converter(
-        torch.ops.tensorrt.dynamic_block_quantize_op.default,
-        supports_dynamic_shapes=True,
-    )
-    def aten_ops_dynamic_block_quantize_op(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-        return impl.dynamic_block_quantize.quantize(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            args[0],
-            args[1],
-            args[2],
-            args[3],
-            args[4],
-            args[5],
-            args[6],
+        @dynamo_tensorrt_converter(
+            torch.ops.tensorrt.dynamic_block_quantize_op.default,
+            supports_dynamic_shapes=True,
         )
+        def aten_ops_dynamic_block_quantize_op(
+            ctx: ConversionContext,
+            target: Target,
+            args: Tuple[Argument, ...],
+            kwargs: Dict[str, Argument],
+            name: str,
+        ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+            return impl.dynamic_block_quantize.quantize(
+                ctx,
+                target,
+                SourceIR.ATEN,
+                name,
+                args[0],
+                args[1],
+                args[2],
+                args[3],
+                args[4],
+                args[5],
+                args[6],
+            )
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 28 additions & 24 deletions
@@ -32,7 +32,7 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -448,31 +448,35 @@ def create_constant(
     if torch_value is not None:
 
         if torch_value.dtype == torch.uint8:
-            if (
-                target_quantized_type is None
-                or target_quantized_type != trt.DataType.FP4
-            ):
-                # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
+            if is_tensorrt_version_supported("10.8.0"):
+                if (
+                    target_quantized_type is None
+                    or target_quantized_type != trt.DataType.FP4
+                ):
+                    # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
+                    raise ValueError(
+                        "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                    )
+                shape[-1] = shape[-1] * 2
+                weights = to_trt_weights(
+                    ctx,
+                    torch_value,
+                    name,
+                    "CONSTANT",
+                    "CONSTANT",
+                    dtype=trt.DataType.FP4,
+                    count=torch_value.numel() * 2,
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                return constant.get_output(0)
+            else:
                 raise ValueError(
-                    "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                    "Currently FP4 is only supported in TensorRT 10.8.0 and above"
                 )
-            shape[-1] = shape[-1] * 2
-            weights = to_trt_weights(
-                ctx,
-                torch_value,
-                name,
-                "CONSTANT",
-                "CONSTANT",
-                dtype=trt.DataType.FP4,
-                count=torch_value.numel() * 2,
-            )
-            constant = ctx.net.add_constant(
-                shape,
-                weights,
-            )
-            constant.name = name
-            return constant.get_output(0)
-
         # Record the weight in ctx for refit and cpu memory reference
 
         # Convert the torch.Tensor to a trt.Weights object
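
The create_constant branch above relies on FP4 values being packed two per uint8 byte, which is why both the last dimension of shape and the weight count are doubled before add_constant is called. A rough standalone illustration of that arithmetic (hypothetical tensor sizes, assuming two 4-bit values per byte):

import torch

# A uint8 tensor carrying packed FP4 data: each byte holds two 4-bit values.
packed = torch.zeros(16, 8, dtype=torch.uint8)

shape = list(packed.shape)
shape[-1] = shape[-1] * 2   # logical shape of the unpacked FP4 constant
count = packed.numel() * 2  # number of FP4 elements handed to TensorRT

assert shape == [16, 16] and count == 256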

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 4 additions & 4 deletions
@@ -12,14 +12,14 @@
     dynamo_tensorrt_converter,
 )
 from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
-from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
-    tensorrt_fused_nccl_all_gather_op,
-    tensorrt_fused_nccl_reduce_scatter_op,
-)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
+    from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+        tensorrt_fused_nccl_all_gather_op,
+        tensorrt_fused_nccl_reduce_scatter_op,
+    )
 
     @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
     def fused_nccl_gather(
