@@ -102,6 +102,35 @@ class Frameworks(Enum):
 }
 
 
+def unified_dtype_converter(
+    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
+) -> Union[np.dtype, torch.dtype, TRTDataType]:
+    """
+    Convert TensorRT, Numpy, or Torch data types to any other of those data types.
+    Args:
+        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
+        to (Frameworks): The framework to convert the data type to.
+    Returns:
+        The equivalent data type in the requested framework.
+    """
+    assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
+    trt_major_version = int(trt.__version__.split(".")[0])
+    if dtype in (np.int8, torch.int8, trt.int8):
+        return DataTypeEquivalence[trt.int8][to]
+    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
+        return DataTypeEquivalence[trt.bool][to]
+    elif dtype in (np.int32, torch.int32, trt.int32):
+        return DataTypeEquivalence[trt.int32][to]
+    elif dtype in (np.int64, torch.int64, trt.int64):
+        return DataTypeEquivalence[trt.int64][to]
+    elif dtype in (np.float16, torch.float16, trt.float16):
+        return DataTypeEquivalence[trt.float16][to]
+    elif dtype in (np.float32, torch.float32, trt.float32):
+        return DataTypeEquivalence[trt.float32][to]
+    else:
+        raise TypeError("%s is not a supported dtype" % dtype)
+
+
 def deallocate_module(module: torch.fx.GraphModule, delete_module: bool = True) -> None:
     """
     This is a helper function to delete the instance of module. We first move it to CPU and then
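
For reference, a minimal usage sketch of the new converter. It assumes the Frameworks enum exposes members such as Frameworks.NUMPY, Frameworks.TORCH, and Frameworks.TRT (member names are not shown in this diff) and that DataTypeEquivalence covers the listed dtypes:

import numpy as np
import torch
import tensorrt as trt

# torch dtype -> TensorRT dtype (assumed member name Frameworks.TRT)
trt_dtype = unified_dtype_converter(torch.float16, Frameworks.TRT)

# TensorRT dtype -> NumPy dtype (assumed member name Frameworks.NUMPY)
np_dtype = unified_dtype_converter(trt.int32, Frameworks.NUMPY)

# Any dtype outside the equivalence table raises TypeError
try:
    unified_dtype_converter(np.uint8, Frameworks.TORCH)
except TypeError as err:
    print(err)  # "<dtype> is not a supported dtype"
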
@@ -875,8 +904,12 @@ def _cache_root() -> Path:
     return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
 
 
-def _extracted_dir_trtllm(platform: str) -> Path:
-    return _cache_root() / "trtllm" / f"{__tensorrt_llm_version__}_{platform}"
+def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
+    return (
+        _cache_root()
+        / "trtllm"
+        / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
+    )
 
 
 def download_and_get_plugin_lib_path() -> Optional[str]:
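
A small illustration of the cache layout the reworked helper produces; the temp directory, username, and TensorRT-LLM version shown in the comment are placeholders, not values from this diff:

import platform

platform_system = platform.system().lower()    # e.g. "linux"
platform_machine = platform.machine().lower()  # e.g. "x86_64"
extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
# e.g. /tmp/torch_tensorrt_<user>/trtllm/<tensorrt_llm_version>_linux_x86_64
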
@@ -889,13 +922,14 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
     Returns:
         Optional[str]: Path to shared library or None if operation fails.
     """
+    platform_system = platform.system().lower()
+    platform_machine = platform.machine().lower()
     wheel_filename = (
         f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
-        f"{_WHL_CPYTHON_VERSION}-{platform}.whl"
+        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
     )
-    platform_system = platform.system().lower()
     wheel_path = _cache_root() / wheel_filename
-    extract_dir = _extracted_dir_trtllm(platform_system)
+    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
     # else will never be met though
     lib_filename = (
         "libnvinfer_plugin_tensorrt_llm.so"