Skip to content

Commit d656160

Browse files
committed
using python platform
1 parent 7bb105d commit d656160

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,35 @@ class Frameworks(Enum):
102102
}
103103

104104

105+
def unified_dtype_converter(
106+
dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
107+
) -> Union[np.dtype, torch.dtype, TRTDataType]:
108+
"""
109+
Convert TensorRT, Numpy, or Torch data types to any other of those data types.
110+
Args:
111+
dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
112+
to (Frameworks): The framework to convert the data type to.
113+
Returns:
114+
The equivalent data type in the requested framework.
115+
"""
116+
assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
117+
trt_major_version = int(trt.__version__.split(".")[0])
118+
if dtype in (np.int8, torch.int8, trt.int8):
119+
return DataTypeEquivalence[trt.int8][to]
120+
elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
121+
return DataTypeEquivalence[trt.bool][to]
122+
elif dtype in (np.int32, torch.int32, trt.int32):
123+
return DataTypeEquivalence[trt.int32][to]
124+
elif dtype in (np.int64, torch.int64, trt.int64):
125+
return DataTypeEquivalence[trt.int64][to]
126+
elif dtype in (np.float16, torch.float16, trt.float16):
127+
return DataTypeEquivalence[trt.float16][to]
128+
elif dtype in (np.float32, torch.float32, trt.float32):
129+
return DataTypeEquivalence[trt.float32][to]
130+
else:
131+
raise TypeError("%s is not a supported dtype" % dtype)
132+
133+
105134
def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
106135
"""
107136
This is a helper function to delete the instance of module. We first move it to CPU and then
@@ -875,8 +904,12 @@ def _cache_root() -> Path:
875904
return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
876905

877906

878-
def _extracted_dir_trtllm(platform: str) -> Path:
879-
return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"
907+
def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
908+
return (
909+
_cache_root()
910+
/ "trtllm"
911+
/ f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
912+
)
880913

881914

882915
def download_and_get_plugin_lib_path() -> Optional[str]:
@@ -889,13 +922,14 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
889922
Returns:
890923
Optional[str]: Path to shared library or None if operation fails.
891924
"""
925+
platform_system = platform.system().lower()
926+
platform_machine = platform.machine().lower()
892927
wheel_filename = (
893928
f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
894-
f"{_WHL_CPYTHON_VERSION}-{platform}.whl"
929+
f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
895930
)
896-
platform_system = platform.system().lower()
897931
wheel_path = _cache_root() / wheel_filename
898-
extract_dir = _extracted_dir_trtllm(platform_system)
932+
extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
899933
# else will never be met though
900934
lib_filename = (
901935
"libnvinfer_plugin_tensorrt_llm.so"

0 commit comments

Comments
 (0)