@@ -52,6 +52,7 @@ def __init__(
5252 vision_batch_size : int = 1 , # max_batch_size in VisionConfig
5353 engine_kwargs : Optional [Dict [str , Any ]] = None ,
5454 template : Optional [Template ] = None ,
55+ devices : Optional [List [int ]] = None ,
5556 ) -> None :
5657 if engine_kwargs is None :
5758 engine_kwargs = {}
@@ -74,6 +75,7 @@ def __init__(
7475 cache_max_entry_count = cache_max_entry_count ,
7576 quant_policy = quant_policy ,
7677 vision_batch_size = vision_batch_size ,
78+ devices = devices ,
7779 ** engine_kwargs )
7880
7981 self .config .torch_dtype = torch_dtype or self .model_info .torch_dtype
@@ -87,11 +89,14 @@ def _prepare_engine_kwargs(self,
8789 cache_max_entry_count : float = 0.8 ,
8890 quant_policy : int = 0 ,
8991 vision_batch_size : int = 1 ,
92+ devices : Optional [List [int ]] = None ,
9093 ** engine_kwargs ):
9194 engine_kwargs ['tp' ] = tp
9295 engine_kwargs ['session_len' ] = session_len
9396 engine_kwargs ['cache_max_entry_count' ] = cache_max_entry_count
9497 engine_kwargs ['quant_policy' ] = quant_policy
98+ if 'devices' in inspect .signature (TurbomindEngineConfig ).parameters :
99+ engine_kwargs ['devices' ] = devices
95100 backend_config = TurbomindEngineConfig (** engine_kwargs )
96101 backend_config = autoget_backend_config (self .model_dir , backend_config )
97102 self .backend_config = backend_config
0 commit comments