Commit 19b5a99

Using python package for platform detection
1 parent 27b6a31 commit 19b5a99

File tree

3 files changed: 62 additions & 59 deletions

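The commit replaces torch_tensorrt's Platform enum (and a CUDA device-name probe) with Python's standard-library platform module. As a quick illustration of the three stdlib calls the new code relies on (example outputs below are typical values, not taken from this commit):

# Illustrative only: the stdlib calls the commit switches to.
import platform

print(platform.system())   # e.g. "Linux" or "Windows"
print(platform.machine())  # e.g. "x86_64" or "aarch64"
print(platform.release())  # kernel release; on Jetson/Tegra boards it typically contains "tegra"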

.github/workflows/build-test-linux-aarch64.yml

Lines changed: 35 additions & 0 deletions
@@ -356,6 +356,41 @@ jobs:
       python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
       popd

+  tests-py-distributed:
+    name: Test dynamo distributed [Python]
+    needs: [filter-matrix, build]
+    if: false
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - repository: pytorch/tensorrt
+            package-name: torch_tensorrt
+            pre-script: packaging/pre_build_script.sh
+            post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
+    uses: ./.github/workflows/linux-test.yml
+    with:
+      job-name: tests-py-dynamo-distributed
+      repository: "pytorch/tensorrt"
+      ref: ""
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
+      pre-script: ${{ matrix.pre-script }}
+      script: |
+        set -euo pipefail
+        export USE_HOST_DEPS=1
+        export CI_BUILD=1
+        export USE_TRTLLM_PLUGINS=1
+        dnf install -y mpich mpich-devel openmpi openmpi-devel
+        pushd .
+        cd tests/py
+        cd dynamo
+        python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_distributed_test_results.xml distributed/test_nccl_ops.py
+        popd
+
+
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
   cancel-in-progress: true
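The new tests-py-distributed job is gated with `if: false`, so it is defined but disabled for now. For local experimentation, a rough Python equivalent of its test step might look like the sketch below (illustrative, not part of the commit; assumes a repo checkout with tests/py/dynamo present and the MPI/TRT-LLM prerequisites already installed):

# Rough local equivalent of the job's test step; illustrative sketch only.
import os
import subprocess

env = dict(os.environ, USE_HOST_DEPS="1", CI_BUILD="1", USE_TRTLLM_PLUGINS="1")
subprocess.run(
    ["python", "-m", "pytest", "-ra", "distributed/test_nccl_ops.py"],
    cwd="tests/py/dynamo",  # matches the pushd/cd sequence in the job script
    env=env,
    check=True,
)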

py/torch_tensorrt/dynamo/utils.py

Lines changed: 26 additions & 54 deletions
@@ -5,6 +5,7 @@
 import getpass
 import logging
 import os
+import platform
 import tempfile
 import urllib.request
 import warnings
@@ -29,7 +30,7 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import Platform, dtype
+from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt._version import __tensorrt_llm_version__
@@ -101,37 +102,6 @@ class Frameworks(Enum):
 }


-def unified_dtype_converter(
-    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
-) -> Union[np.dtype, torch.dtype, TRTDataType]:
-    """
-    Convert TensorRT, Numpy, or Torch data types to any other of those data types.
-
-    Args:
-        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
-        to (Frameworks): The framework to convert the data type to.
-
-    Returns:
-        The equivalent data type in the requested framework.
-    """
-    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
-    trt_major_version = int(trt.__version__.split(".")[0])
-    if dtype in (np.int8, torch.int8, trt.int8):
-        return DataTypeEquivalence[trt.int8][to]
-    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
-        return DataTypeEquivalence[trt.bool][to]
-    elif dtype in (np.int32, torch.int32, trt.int32):
-        return DataTypeEquivalence[trt.int32][to]
-    elif dtype in (np.int64, torch.int64, trt.int64):
-        return DataTypeEquivalence[trt.int64][to]
-    elif dtype in (np.float16, torch.float16, trt.float16):
-        return DataTypeEquivalence[trt.float16][to]
-    elif dtype in (np.float32, torch.float32, trt.float32):
-        return DataTypeEquivalence[trt.float32][to]
-    else:
-        raise TypeError("%s is not a supported dtype" % dtype)
-
-
 def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
@@ -870,29 +840,33 @@ def is_tegra_platform() -> bool:
     return False


-def is_platform_supported_for_trtllm(platform: str) -> bool:
+def is_platform_supported_for_trtllm() -> bool:
     """
-    Checks if the current platform supports TensorRT-LLM plugins for NCCL backend
+    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.
+
     Returns:
-        bool: True if the platform supports TensorRT-LLM plugins for NCCL backend, False otherwise.
-    Note:
-        TensorRT-LLM plugins for NCCL backend are not supported on:
-        - Windows platforms
-        - Orin, Xavier, or Tegra devices (aarch64 architecture)
+        bool: True if supported, False otherwise.

+    Unsupported:
+        - Windows platforms
+        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
     """
-    if "windows" in platform:
+    system = platform.system().lower()
+    machine = platform.machine().lower()
+    release = platform.release().lower()
+
+    if "windows" in system:
         logger.info(
-            "TensorRT-LLM plugins for NCCL backend are not supported on Windows"
+            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
         )
         return False
-    if torch.cuda.is_available():
-        device_name = torch.cuda.get_device_name().lower()
-        if any(keyword in device_name for keyword in ["orin", "xavier", "tegra"]):
-            return False
+
+    if machine == "aarch64" and "tegra" in release:
         logger.info(
-            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices"
+            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices."
         )
+        return False
+
     return True


@@ -905,7 +879,7 @@ def _extracted_dir_trtllm(platform: str) -> Path:
     return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"


-def download_and_get_plugin_lib_path(platform: str) -> Optional[str]:
+def download_and_get_plugin_lib_path() -> Optional[str]:
     """
     Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.

@@ -919,12 +893,13 @@ def download_and_get_plugin_lib_path(platform: str) -> Optional[str]:
         f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
         f"{_WHL_CPYTHON_VERSION}-{platform}.whl"
     )
+    platform_system = platform.system().lower()
     wheel_path = _cache_root() / wheel_filename
-    extract_dir = _extracted_dir_trtllm(platform)
+    extract_dir = _extracted_dir_trtllm(platform_system)
     # else will never be met though
     lib_filename = (
         "libnvinfer_plugin_tensorrt_llm.so"
-        if "linux" in platform
+        if "linux" in platform_system
         else "libnvinfer_plugin_tensorrt_llm.dll"
     )
     # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
@@ -1057,10 +1032,7 @@ def load_tensorrt_llm_for_nccl() -> bool:
     Returns:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
-    # Check platform compatibility first
-    platform = Platform.current_platform()
-    platform = str(platform).lower()
-    if not is_platform_supported_for_trtllm(platform):
+    if not is_platform_supported_for_trtllm():
         return False
     plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")

@@ -1080,6 +1052,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
             )
             return False

-        plugin_lib_path = download_and_get_plugin_lib_path(platform)
+        plugin_lib_path = download_and_get_plugin_lib_path()
         return load_and_initialize_trtllm_plugin(plugin_lib_path)  # type: ignore[arg-type]
     return False
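Taken together, the rewritten gate reduces to a small, self-contained predicate. A condensed sketch of what is_platform_supported_for_trtllm now does (mirroring the diff above, with logging omitted and a hypothetical standalone name):

# Condensed sketch of the new platform gate; logging omitted, name is illustrative.
import platform

def trtllm_nccl_platform_ok() -> bool:
    if "windows" in platform.system().lower():
        return False  # TensorRT-LLM plugins for NCCL are not supported on Windows
    if platform.machine().lower() == "aarch64" and "tegra" in platform.release().lower():
        return False  # Jetson/Orin/Xavier (Tegra) devices are unsupported
    return True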

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 1 addition & 5 deletions
@@ -8,7 +8,6 @@
 from distributed_utils import set_environment_variables_pytest
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._enums import Platform
 from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm


@@ -42,12 +41,9 @@ def forward(self, x):
         return torch.ops._c10d_functional.wait_tensor(out)


-platform_str = str(Platform.current_platform()).lower()
-
-
 class TestNcclOpsConverter(DispatchTestCase):
     @unittest.skipIf(
-        not is_platform_supported_for_trtllm(platform_str),
+        not is_platform_supported_for_trtllm(),
         "Skipped on Windows, Jetson: NCCL backend is not supported.",
     )
     @classmethod
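Because the helper is now argument-free, the gate can be applied without precomputing a platform string at import time. A hypothetical pytest-style equivalent of the skip above, shown purely for illustration (not part of the commit):

# Hypothetical pytest-style variant of the same gate.
import pytest
from torch_tensorrt.dynamo.utils import is_platform_supported_for_trtllm

pytestmark = pytest.mark.skipif(
    not is_platform_supported_for_trtllm(),
    reason="Skipped on Windows, Jetson: NCCL backend is not supported.",
)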
