@@ -36,6 +36,7 @@ class AbsEmbedder(ABC):
3636 passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
3737 convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
3838 Defaults to :data:`True`.
39+	multi_GPU_type (str, optional): The multi-GPU inference strategy, one of :data:`'dp'` (DataParallel) or :data:`'multi_process'`. Defaults to :data:`"dp"`.
3940 kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
4041 """
4142
@@ -52,6 +53,7 @@ def __init__(
5253 query_max_length : int = 512 ,
5354 passage_max_length : int = 512 ,
5455 convert_to_numpy : bool = True ,
56+ multi_GPU_type : str = 'dp' ,
5557 ** kwargs : Any ,
5658 ):
5759 query_instruction_format = query_instruction_format .replace ('\\ n' , '\n ' )
@@ -66,6 +68,8 @@ def __init__(
6668 self .query_max_length = query_max_length
6769 self .passage_max_length = passage_max_length
6870 self .convert_to_numpy = convert_to_numpy
71+ self ._multi_GPU_type = multi_GPU_type
72+ self ._dp_set = False
6973
7074 for k in kwargs :
7175 setattr (self , k , kwargs [k ])
@@ -77,6 +81,21 @@ def __init__(
7781 self .model = None
7882 self .pool = None
7983
def start_dp(self):
    """Wrap ``self.model`` in ``torch.nn.DataParallel`` for single-process multi-GPU inference.

    This is a no-op unless ALL of the following hold:
      * ``self._multi_GPU_type == 'dp'``
      * ``self.target_devices`` is a list with more than one entry
      * the entries name CUDA devices (plain ints, or strings containing ``'cuda'``)
      * DataParallel has not already been installed (``self._dp_set`` is falsy)

    Side effects when it runs: optionally halves the model (``self.use_fp16``),
    moves it to the first target GPU, replaces ``self.model`` with a
    ``DataParallel`` wrapper, and sets ``self._dp_set = True`` so repeated
    calls don't double-wrap.
    """
    # Guard clauses instead of one compound condition.
    if self._multi_GPU_type != 'dp' or self._dp_set:
        return
    if not (isinstance(self.target_devices, list) and len(self.target_devices) > 1):
        return
    first = self.target_devices[0]
    if not (isinstance(first, int) or 'cuda' in first):
        return

    # Normalize 'cuda:N' strings to integer ids once, so both input forms
    # share a single DataParallel setup path.
    if isinstance(first, int):
        device_ids = list(self.target_devices)
    else:
        device_ids = [int(d.split(':')[-1].strip()) for d in self.target_devices]

    if self.use_fp16:
        self.model.half()
    # BUGFIX: DataParallel requires the module's parameters to live on
    # device_ids[0]. The previous code moved the model to the generic "cuda"
    # device (cuda:0), which is wrong whenever the first target GPU is not 0.
    self.model = self.model.to(torch.device(f"cuda:{device_ids[0]}"))
    self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
    self._dp_set = True
8099 def stop_self_pool (self ):
81100 if self .pool is not None :
82101 self .stop_multi_process_pool (self .pool )
@@ -107,7 +126,10 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
107126 elif is_torch_npu_available ():
108127 return [f"npu:{ i } " for i in range (torch .npu .device_count ())]
109128 elif torch .backends .mps .is_available ():
110- return [f"mps:{ i } " for i in range (torch .mps .device_count ())]
129+ try :
130+ return [f"mps:{ i } " for i in range (torch .mps .device_count ())]
131+ except :
132+ return ["mps" ]
111133 else :
112134 return ["cpu" ]
113135 elif isinstance (devices , str ):
@@ -253,6 +275,15 @@ def encode(
253275 device = self .target_devices [0 ],
254276 ** kwargs
255277 )
278+
279+ if self ._multi_GPU_type == 'dp' :
280+ return self .encode_only (
281+ sentences ,
282+ batch_size = batch_size ,
283+ max_length = max_length ,
284+ convert_to_numpy = convert_to_numpy ,
285+ ** kwargs
286+ )
256287
257288 if self .pool is None :
258289 self .pool = self .start_multi_process_pool (AbsEmbedder ._encode_multi_process_worker )
@@ -262,6 +293,7 @@ def encode(
262293 batch_size = batch_size ,
263294 max_length = max_length ,
264295 convert_to_numpy = convert_to_numpy ,
296+ device = torch .device ("cuda" ),
265297 ** kwargs
266298 )
267299 return embeddings
@@ -284,6 +316,20 @@ def encode_single_device(
284316 """
285317 pass
286318
def encode_only(
    self,
    sentences: Union[List[str], str],
    batch_size: int = 256,
    max_length: int = 512,
    convert_to_numpy: bool = True,
    device: Any = None,
    **kwargs: Any,
):
    """Encode ``sentences`` on a single device and return their embeddings.

    Abstract-by-convention hook (mirrors ``encode_single_device``): the base
    class provides no implementation, so this default body does nothing and
    returns ``None``. Subclasses are expected to override it with the actual
    single-device encoding logic used by the data-parallel path.

    Args:
        sentences: A single sentence or a list of sentences to embed.
        batch_size: Number of sentences processed per forward pass.
        max_length: Maximum tokenized sequence length.
        convert_to_numpy: If True, return a NumPy array; otherwise a Tensor.
        device: Target device for the computation (implementation-defined).
        **kwargs: Extra options forwarded to the concrete implementation.
    """
    pass
287333 # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
288334 def start_multi_process_pool (
289335 self ,
0 commit comments