Skip to content

Commit 6550195

Browse files
committed
adding SBSA to the warning log and installing MPI libs for the distributed tests
1 parent 43b5ade commit 6550195

File tree

5 files changed

+4
-13
lines changed

5 files changed

+4
-13
lines changed

.github/scripts/install-mpi-linux-x86.sh

Lines changed: 0 additions & 4 deletions
This file was deleted.

.github/scripts/install-torch-tensorrt.sh

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,6 @@ if [[ $(uname -m) == "aarch64" ]]; then
1212
install_cuda_aarch64
1313
fi
1414

15-
if [[ "$(uname -s)" == "Linux" && "$(uname -m)" == "x86_64" ]]; then
16-
# install MPI for Linux x86_64
17-
source .github/scripts/install-mpi-linux-x86.sh
18-
install_mpi_linux_x86
19-
fi
20-
2115
# Install all the dependencies required for Torch-TensorRT
2216
pip install --pre -r ${PWD}/tests/py/requirements.txt
2317
# dependencies in the tests/py/requirements.txt might install a different version of torch or torchvision

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ jobs:
363363
export USE_HOST_DEPS=1
364364
export CI_BUILD=1
365365
export USE_TRTLLM_PLUGINS=1
366+
dnf install -y mpich mpich-devel openmpi openmpi-devel
366367
pushd .
367368
cd tests/py
368369
cd dynamo

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def is_platform_supported_for_trtllm(platform: str) -> bool:
856856
return False
857857
if "aarch64" in platform:
858858
logger.info(
859-
"TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices (aarch64)"
859+
"TensorRT-LLM plugins for NCCL backend are not supported on Jetson devices or SBSA (aarch64)"
860860
)
861861
return False
862862
return True

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def forward(self, x):
4444
platform_str = str(Platform.current_platform()).lower()
4545

4646

47-
class TestGatherNcclOpsConverter(DispatchTestCase):
47+
class TestNcclOpsConverter(DispatchTestCase):
4848
@unittest.skipIf(
4949
"win" in platform_str or "aarch64" in platform_str,
50-
"Skipped on Windows and Jetson: NCCL backend is not supported.",
50+
"Skipped on Windows, Jetson, SBSA: NCCL backend is not supported.",
5151
)
5252
@classmethod
5353
def setUpClass(cls):

0 commit comments

Comments
 (0)