5
5
import getpass
6
6
import logging
7
7
import os
8
+ import platform
8
9
import tempfile
9
10
import urllib .request
10
11
import warnings
29
30
from torch ._subclasses .fake_tensor import FakeTensor
30
31
from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
31
32
from torch_tensorrt ._Device import Device
32
- from torch_tensorrt ._enums import Platform , dtype
33
+ from torch_tensorrt ._enums import dtype
33
34
from torch_tensorrt ._features import ENABLED_FEATURES
34
35
from torch_tensorrt ._Input import Input
35
36
from torch_tensorrt ._version import __tensorrt_llm_version__
@@ -101,37 +102,6 @@ class Frameworks(Enum):
101
102
}
102
103
103
104
104
def unified_dtype_converter(
    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
) -> Union[np.dtype, torch.dtype, TRTDataType]:
    """Translate a TensorRT, Numpy, or Torch data type into another framework.

    Args:
        dtype (TRTDataType, torch.dtype, np.dtype): The data type to translate.
        to (Frameworks): The target framework for the translation.

    Returns:
        The equivalent data type in the requested framework.

    Raises:
        TypeError: If ``dtype`` is not one of the supported data types.
    """
    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
    trt_major_version = int(trt.__version__.split(".")[0])

    # Each entry: (equivalent dtype spellings, canonical TRT key, minimum TRT
    # major version required). bool support requires TensorRT >= 7.
    dtype_table = (
        ((np.int8, torch.int8, trt.int8), trt.int8, 0),
        ((np.bool_, torch.bool, trt.bool), trt.bool, 7),
        ((np.int32, torch.int32, trt.int32), trt.int32, 0),
        ((np.int64, torch.int64, trt.int64), trt.int64, 0),
        ((np.float16, torch.float16, trt.float16), trt.float16, 0),
        ((np.float32, torch.float32, trt.float32), trt.float32, 0),
    )

    for spellings, trt_key, min_trt_major in dtype_table:
        if trt_major_version >= min_trt_major and dtype in spellings:
            return DataTypeEquivalence[trt_key][to]

    raise TypeError("%s is not a supported dtype" % dtype)
105
def deallocate_module (module : torch .fx .GraphModule , delete_module : bool = True ) -> None :
136
106
"""
137
107
This is a helper function to delete the instance of module. We first move it to CPU and then
@@ -870,29 +840,33 @@ def is_tegra_platform() -> bool:
870
840
return False
871
841
872
842
873
def is_platform_supported_for_trtllm() -> bool:
    """
    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.

    Returns:
        bool: True if supported, False otherwise.

    Unsupported:
        - Windows platforms
        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
    """
    # Windows builds of the NCCL-backend plugins are not available.
    if "windows" in platform.system().lower():
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Windows. "
        )
        return False

    # Tegra-based devices (Jetson/Orin/Xavier) report aarch64 and carry
    # 'tegra' in the kernel release string.
    on_aarch64 = platform.machine().lower() == "aarch64"
    on_tegra = "tegra" in platform.release().lower()
    if on_aarch64 and on_tegra:
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices. "
        )
        return False

    return True
897
871
898
872
@@ -905,7 +879,7 @@ def _extracted_dir_trtllm(platform: str) -> Path:
905
879
return _cache_root () / "trtllm" / f"{ __tensorrt_llm_version__ } _{ platform } "
906
880
907
881
908
- def download_and_get_plugin_lib_path (platform : str ) -> Optional [str ]:
882
+ def download_and_get_plugin_lib_path () -> Optional [str ]:
909
883
"""
910
884
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
911
885
@@ -919,12 +893,13 @@ def download_and_get_plugin_lib_path(platform: str) -> Optional[str]:
919
893
f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -"
920
894
f"{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
921
895
)
896
+ platform_system = platform .system ().lower ()
922
897
wheel_path = _cache_root () / wheel_filename
923
- extract_dir = _extracted_dir_trtllm (platform )
898
+ extract_dir = _extracted_dir_trtllm (platform_system )
924
899
# else will never be met though
925
900
lib_filename = (
926
901
"libnvinfer_plugin_tensorrt_llm.so"
927
- if "linux" in platform
902
+ if "linux" in platform_system
928
903
else "libnvinfer_plugin_tensorrt_llm.dll"
929
904
)
930
905
# eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
@@ -1057,10 +1032,7 @@ def load_tensorrt_llm_for_nccl() -> bool:
1057
1032
Returns:
1058
1033
bool: True if the plugin was successfully loaded and initialized, False otherwise.
1059
1034
"""
1060
- # Check platform compatibility first
1061
- platform = Platform .current_platform ()
1062
- platform = str (platform ).lower ()
1063
- if not is_platform_supported_for_trtllm (platform ):
1035
+ if not is_platform_supported_for_trtllm ():
1064
1036
return False
1065
1037
plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
1066
1038
@@ -1080,6 +1052,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
1080
1052
)
1081
1053
return False
1082
1054
1083
- plugin_lib_path = download_and_get_plugin_lib_path (platform )
1055
+ plugin_lib_path = download_and_get_plugin_lib_path ()
1084
1056
return load_and_initialize_trtllm_plugin (plugin_lib_path ) # type: ignore[arg-type]
1085
1057
return False
0 commit comments