
Commit ea861af

enable back jetpack build (#3720)
1 parent 7fb0d3a commit ea861af

File tree: 15 files changed, +131 / -121 lines


.github/workflows/build-test-linux-aarch64-jetpack.yml

Lines changed: 10 additions & 10 deletions
@@ -1,16 +1,16 @@
 name: Build and test Linux aarch64 wheels for Jetpack

 on:
-  # pull_request:
-  # push:
-  #   branches:
-  #     - main
-  #     - nightly
-  #     - release/*
-  #   tags:
-  #     # NOTE: Binary build pipelines should only get triggered on release candidate builds
-  #     # Release candidate tags look like: v1.11.0-rc1
-  #     - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+  pull_request:
+  push:
+    branches:
+      - main
+      - nightly
+      - release/*
+    tags:
+      # NOTE: Binary build pipelines should only get triggered on release candidate builds
+      # Release candidate tags look like: v1.11.0-rc1
+      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
   workflow_dispatch:

 jobs:

.github/workflows/build_wheels_linux_aarch64.yml

Lines changed: 3 additions & 3 deletions
@@ -264,7 +264,7 @@ jobs:
           if [[ ${{ inputs.is-jetpack }} == false ]]; then
             ${CONDA_RUN} python setup.py bdist_wheel
           else
-            ${CONDA_RUN} python setup.py bdist_wheel --jetpack --plat-name=linux_tegra_aarch64
+            ${CONDA_RUN} python setup.py bdist_wheel --jetpack
           fi
       - name: Repair Manylinux_2_28 Wheel
         shell: bash -l {0}
@@ -336,8 +336,8 @@ jobs:
   upload:
     needs: build
     uses: pytorch/test-infra/.github/workflows/_binary_upload.yml@main
-    # for jetpack builds, only upload to pytorch index for nightly builds
-    if: ${{ inputs.is-jetpack == false || (github.event_name == 'push' && startsWith(github.event.ref, 'refs/heads/nightly')) }}
+    # for jetpack builds, do not upload to pytorch nightly index, only upload to https://pypi.jetson-ai-lab.io/ manually for each release
+    if: ${{ inputs.is-jetpack == false }}
     with:
       repository: ${{ inputs.repository }}
       ref: ${{ inputs.ref }}

MODULE.bazel

Lines changed: 1 addition & 2 deletions
@@ -90,10 +90,9 @@ http_archive(
 http_archive(
     name = "torch_l4t",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "6eff643c0a7acda92734cc798338f733ff35c7df1a4434576f5ff7c66fc97319",
     strip_prefix = "torch",
     type = "zip",
-    urls = ["https://pypi.jetson-ai-lab.dev/jp6/cu126/+f/6ef/f643c0a7acda9/torch-2.7.0-cp310-cp310-linux_aarch64.whl"],
+    urls = ["https://pypi.jetson-ai-lab.io/jp6/cu126/+f/62a/1beee9f2f1470/torch-2.8.0-cp310-cp310-linux_aarch64.whl"],
 )

 # Download these tarballs manually from the NVIDIA website

docsrc/getting_started/jetpack.rst

Lines changed: 4 additions & 4 deletions
@@ -90,22 +90,22 @@ Build Environment Setup
 .. code-block:: sh

     # Can only install the torch and torchvision wheel from the JPL repo which is built specifically for JetPack 6.2
-    python -m pip install torch==2.7.0 torchvision==0.22.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
+    python -m pip install torch==2.8.0 torchvision==0.23.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126


 Building the Wheel
 ==================

 .. code-block:: sh
-    python setup.py bdist_wheel
+    python setup.py bdist_wheel --jetpack

 Installation
 ============

 .. code-block:: sh
-    # you will be able to find the wheel in the dist directory, has platform name linux_tegra_aarch64
+    # you will be able to find the wheel in the dist directory
     cd dist
-    python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_tegra_aarch64.whl
+    python -m pip install torch_tensorrt-2.8.0.dev0+d8318d8fc-cp310-cp310-linux_aarch64.whl

 Post-Installation Verification
 ==============================
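The docs diff above stops at the "Post-Installation Verification" heading, whose body is not shown in this commit. As a minimal sketch of such a check (an assumption, not content from the docs page), importing the freshly installed wheel and printing versions is usually enough to confirm the install:

import torch
import torch_tensorrt  # noqa: F401  # the import itself succeeding is the main check

print(torch.__version__)           # expect 2.8.0, per the pins in this commit
print(torch_tensorrt.__version__)  # expect a 2.8.0.dev* build
print(torch.cuda.is_available())   # expect True on a JetPack 6.2 device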

packaging/pre_build_script.sh

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ pip uninstall -y torch torchvision

 if [[ ${IS_JETPACK} == true ]]; then
     # install torch 2.7 for jp6.2
-    pip install torch==2.7.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
+    pip install torch==2.8.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126/
 else
     TORCH=$(grep "^torch>" py/requirements.txt)
     INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}

py/torch_tensorrt/_enums.py

Lines changed: 7 additions & 4 deletions
@@ -8,6 +8,7 @@
 import tensorrt as trt
 import torch
 from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
+from torch_tensorrt._utils import is_tensorrt_version_supported


 class dtype(Enum):
@@ -199,8 +200,6 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
-        elif t == trt.DataType.FP4:
-            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:
@@ -214,6 +213,9 @@ def _from(
         elif t == trt.DataType.BF16:
             return dtype.bf16
         else:
+            if is_tensorrt_version_supported("10.8.0"):
+                if t == trt.DataType.FP4:
+                    return dtype.fp4
             raise TypeError(
                 f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
             )
@@ -409,11 +411,12 @@ def to(
             return trt.DataType.BOOL
         elif self == dtype.bf16:
             return trt.DataType.BF16
-        elif self == dtype.f4:
-            return trt.DataType.FP4
         elif use_default:
             return trt.DataType.FLOAT
         else:
+            if is_tensorrt_version_supported("10.8.0"):
+                if self == dtype.f4:
+                    return trt.DataType.FP4
             raise TypeError("Unsupported tensorrt dtype")

     elif t == np.dtype:
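The effect of this change is that the FP4 mappings in dtype._from and dtype.to only fire when the installed TensorRT is at least 10.8.0; on older versions both directions fall through to the existing TypeError. A usage sketch (assumptions: TensorRT >= 10.8.0 is installed, dtype.fp4 aliases dtype.f4 since the diff uses the two names interchangeably, and the private _from helper is called here purely for illustration):

import tensorrt as trt
from torch_tensorrt import dtype

d = dtype._from(trt.DataType.FP4)  # dtype.fp4 on TRT >= 10.8.0, TypeError otherwise
print(d.to(trt.DataType))          # round-trips back to trt.DataType.FP4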

py/torch_tensorrt/_utils.py

Lines changed: 26 additions & 0 deletions
@@ -24,3 +24,29 @@ def check_cross_compile_trt_win_lib() -> bool:
         target_lib = ".*libnvinfer_builder_resource_win.so.*"
         return any(re.match(target_lib, lib) for lib in loaded_libs)
     return False
+
+
+def is_tensorrt_version_supported(min_version: str = "10.8.0") -> bool:
+    """
+    Check if the installed TensorRT version supports the specified minimum version.
+
+    Args:
+        min_version (str): Minimum required TensorRT version (default: "10.8.0" for FP4 support)
+
+    Returns:
+        bool: True if TensorRT version is >= min_version, False otherwise
+
+    Example:
+        >>> if is_tensorrt_version_supported("10.8.0"):
+        ...     # Use FP4 features
+        ...     pass
+    """
+    try:
+        from importlib import metadata
+
+        from packaging.version import Version
+
+        return bool(Version(metadata.version("tensorrt")) >= Version(min_version))
+    except (ImportError, ValueError):
+        # If tensorrt is not installed or version cannot be determined
+        return False
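The new helper reads the installed tensorrt distribution's version via importlib.metadata and compares it with packaging.version, returning False instead of raising when tensorrt is missing (importlib.metadata.PackageNotFoundError subclasses ImportError, which is why the except clause above catches it). Inlined as a standalone sketch of the same check:

from importlib import metadata
from packaging.version import Version

try:
    # Mirrors the helper body: True iff the installed TensorRT is >= 10.8.0.
    fp4_ok = Version(metadata.version("tensorrt")) >= Version("10.8.0")
except (metadata.PackageNotFoundError, ValueError):
    fp4_ok = False  # tensorrt not installed, or its version string is unparsable
print(fp4_ok)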

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def aot_torch_tensorrt_aten_backend(
         decompositions=settings_aot_autograd["decompositions"],
     )(gm, sample_inputs)

-    if is_tegra_platform():
+    if not is_tegra_platform():
         from torch.distributed.tensor import DTensor

         if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
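The old guard was inverted: the DTensor check now runs only when not on a Tegra platform, presumably because the torch.distributed stack is unavailable in the Jetson builds (an inference from the change, not stated in the commit). Reduced to a sketch, with the import path of is_tegra_platform assumed:

from torch_tensorrt.dynamo.utils import is_tegra_platform  # assumed import path

def uses_dtensor(sample_inputs) -> bool:
    # Mirrors the corrected control flow in the diff above.
    if not is_tegra_platform():
        from torch.distributed.tensor import DTensor
        return any(isinstance(t, DTensor) for t in sample_inputs)
    return False  # DTensor handling is skipped on Tegra/Jetson builds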

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 31 deletions
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -619,40 +620,41 @@ def aten_ops_quantize_op(
     )


-try:
-    import modelopt.torch.quantization as mtq  # noqa: F401
+if is_tensorrt_version_supported("10.8.0"):
+    try:
+        import modelopt.torch.quantization as mtq  # noqa: F401

-    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
-except Exception as e:
-    _LOGGER.warning(
-        "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
-    )
-else:
+        assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+    except Exception as e:
+        _LOGGER.warning(
+            "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+        )
+    else:

-    @dynamo_tensorrt_converter(
-        torch.ops.tensorrt.dynamic_block_quantize_op.default,
-        supports_dynamic_shapes=True,
-    )
-    def aten_ops_dynamic_block_quantize_op(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-        return impl.dynamic_block_quantize.quantize(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            args[0],
-            args[1],
-            args[2],
-            args[3],
-            args[4],
-            args[5],
-            args[6],
-        )
+        @dynamo_tensorrt_converter(
+            torch.ops.tensorrt.dynamic_block_quantize_op.default,
+            supports_dynamic_shapes=True,
+        )
+        def aten_ops_dynamic_block_quantize_op(
+            ctx: ConversionContext,
+            target: Target,
+            args: Tuple[Argument, ...],
+            kwargs: Dict[str, Argument],
+            name: str,
+        ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+            return impl.dynamic_block_quantize.quantize(
+                ctx,
+                target,
+                SourceIR.ATEN,
+                name,
+                args[0],
+                args[1],
+                args[2],
+                args[3],
+                args[4],
+                args[5],
+                args[6],
+            )


 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
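Net effect: the converter for torch.ops.tensorrt.dynamic_block_quantize_op is registered only when TensorRT is at least 10.8.0 and modelopt imports cleanly; previously only the modelopt check existed. Downstream code can branch on the same predicate this commit introduces, as a sketch:

from torch_tensorrt._utils import is_tensorrt_version_supported

if is_tensorrt_version_supported("10.8.0"):
    print("FP4 / dynamic block quantize converters may be registered")
else:
    print("TensorRT < 10.8.0: the FP4 quantize path is unavailable")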

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 28 additions & 22 deletions
@@ -24,6 +24,7 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -447,30 +448,35 @@ def create_constant(
     if torch_value is not None:

         if torch_value.dtype == torch.uint8:
-            if (
-                target_quantized_type is None
-                or target_quantized_type != trt.DataType.FP4
-            ):
-                # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
-                raise ValueError(
-                    "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
-                )
-            shape[-1] = shape[-1] * 2
-            weights = to_trt_weights(
-                ctx,
-                torch_value,
-                name,
-                "CONSTANT",
-                "CONSTANT",
-                dtype=trt.DataType.FP4,
-                count=torch_value.numel() * 2,
-            )
-            constant = ctx.net.add_constant(
-                shape,
-                weights,
-            )
-            constant.name = name
-            return constant.get_output(0)
+            if is_tensorrt_version_supported("10.8.0"):
+                if (
+                    target_quantized_type is None
+                    or target_quantized_type != trt.DataType.FP4
+                ):
+                    # Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
+                    raise ValueError(
+                        "Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                    )
+                shape[-1] = shape[-1] * 2
+                weights = to_trt_weights(
+                    ctx,
+                    torch_value,
+                    name,
+                    "CONSTANT",
+                    "CONSTANT",
+                    dtype=trt.DataType.FP4,
+                    count=torch_value.numel() * 2,
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                return constant.get_output(0)
+            else:
+                raise ValueError(
+                    "Currently FP4 is only supported in TensorRT 10.8.0 and above"
+                )

     # Record the weight in ctx for refit and cpu memory reference
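For context on the arithmetic being gated here: TensorRT's constant layer does not accept uint8 directly, so the uint8 buffer is treated as FP4 data with two 4-bit values packed per byte, which is why the last dimension and the weight count both double. A sketch of just that arithmetic (to_trt_weights and the ConversionContext are elided; the shapes are made up for illustration):

import torch

# A (128, 64) uint8 buffer carries 128 x 128 FP4 elements, two per byte.
packed = torch.randint(0, 256, (128, 64), dtype=torch.uint8)

shape = list(packed.shape)
shape[-1] = shape[-1] * 2    # [128, 64] -> [128, 128], as in the diff
count = packed.numel() * 2   # 8192 bytes -> 16384 FP4 elements
print(shape, count)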
