Commit f9f673e

Merge pull request #1349 from 545999961/master
Fix Bug: OOM
2 parents 3c40623 + 62b6a1d commit f9f673e

File tree: 9 files changed, +88 −10 lines

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 47 additions & 1 deletion
@@ -36,6 +36,7 @@ class AbsEmbedder(ABC):
         passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
         convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
             Defaults to :data:`True`.
+        multi_GPU_type (str): The type of multi-GPU inference. Defaults to :data:`"dp"`. You can choose ['dp', 'multi_process'].
         kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
     """

@@ -52,6 +53,7 @@ def __init__(
         query_max_length: int = 512,
         passage_max_length: int = 512,
         convert_to_numpy: bool = True,
+        multi_GPU_type: str = 'dp',
         **kwargs: Any,
     ):
         query_instruction_format = query_instruction_format.replace('\\n', '\n')
@@ -66,6 +68,8 @@ def __init__(
         self.query_max_length = query_max_length
         self.passage_max_length = passage_max_length
         self.convert_to_numpy = convert_to_numpy
+        self._multi_GPU_type = multi_GPU_type
+        self._dp_set = False
 
         for k in kwargs:
             setattr(self, k, kwargs[k])
@@ -77,6 +81,21 @@ def __init__(
         self.model = None
         self.pool = None
 
+    def start_dp(self):
+        if self._multi_GPU_type == 'dp' and \
+            (isinstance(self.target_devices, list) and len(self.target_devices) > 1) and \
+            (isinstance(self.target_devices[0], int) or 'cuda' in self.target_devices[0]) and \
+            self._dp_set == False:
+
+            if self.use_fp16: self.model.half()
+            self.model = self.model.to(torch.device("cuda"))
+            if isinstance(self.target_devices[0], int):
+                self.model = torch.nn.DataParallel(self.model, device_ids=self.target_devices)
+            else:
+                devices = [int(e.split(':')[-1].strip()) for e in self.target_devices]
+                self.model = torch.nn.DataParallel(self.model, device_ids=devices)
+            self._dp_set = True
+
     def stop_self_pool(self):
         if self.pool is not None:
             self.stop_multi_process_pool(self.pool)
@@ -107,7 +126,10 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
         elif is_torch_npu_available():
             return [f"npu:{i}" for i in range(torch.npu.device_count())]
         elif torch.backends.mps.is_available():
-            return [f"mps:{i}" for i in range(torch.mps.device_count())]
+            try:
+                return [f"mps:{i}" for i in range(torch.mps.device_count())]
+            except:
+                return ["mps"]
         else:
             return ["cpu"]
     elif isinstance(devices, str):
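The try/except added here works around PyTorch builds where MPS is reported available but torch.mps.device_count() does not exist yet. A standalone sketch of the same fallback, using an explicit hasattr probe instead of the bare except (the helper name mps_targets is ours, for illustration only):

import torch

def mps_targets() -> list:
    # Enumerate indexed MPS devices when this PyTorch build exposes
    # torch.mps.device_count(); otherwise fall back to the single
    # unindexed "mps" device, as the patched code does.
    if torch.backends.mps.is_available():
        if hasattr(torch.mps, "device_count"):
            return [f"mps:{i}" for i in range(torch.mps.device_count())]
        return ["mps"]
    return ["cpu"]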
@@ -253,6 +275,15 @@ def encode(
                 device=self.target_devices[0],
                 **kwargs
             )
+
+        if self._multi_GPU_type == 'dp':
+            return self.encode_only(
+                sentences,
+                batch_size=batch_size,
+                max_length=max_length,
+                convert_to_numpy=convert_to_numpy,
+                **kwargs
+            )
 
         if self.pool is None:
             self.pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
@@ -262,6 +293,7 @@
             batch_size=batch_size,
             max_length=max_length,
             convert_to_numpy=convert_to_numpy,
+            device=torch.device("cuda"),
             **kwargs
         )
         return embeddings
@@ -284,6 +316,20 @@ def encode_single_device(
         """
         pass
 
+    def encode_only(
+        self,
+        sentences: Union[List[str], str],
+        batch_size: int = 256,
+        max_length: int = 512,
+        convert_to_numpy: bool = True,
+        device: Any = None,
+        **kwargs: Any,
+    ):
+        """
+        This method should encode sentences and return embeddings on a single device.
+        """
+        pass
+
     # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
     def start_multi_process_pool(
         self,
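
For orientation: the new 'dp' path above amounts to wrapping the already-loaded model in torch.nn.DataParallel once (start_dp) and then encoding in the parent process (encode_only), letting DataParallel shard each batch across the visible GPUs instead of spawning one worker process per device. A minimal sketch of that pattern, with a toy module standing in for the HuggingFace model FlagEmbedding actually loads (an assumption for illustration, not FlagEmbedding code):

import torch

# Toy stand-in for the embedding model.
model = torch.nn.Linear(768, 768)

device_ids = list(range(torch.cuda.device_count()))
if len(device_ids) > 1:
    # Mirrors start_dp: move to the default CUDA device, then let
    # DataParallel replicate the module and split each input batch.
    model = model.to("cuda")
    model = torch.nn.DataParallel(model, device_ids=device_ids)
elif torch.cuda.is_available():
    model = model.to("cuda")

model.eval()
with torch.no_grad():
    x = torch.randn(32, 768, device=next(model.parameters()).device)
    embeddings = model(x)  # one process; batch sharded across GPUs

Keeping a single owner process avoids the per-worker model replicas and CUDA contexts of the 'multi_process' pool, which is plausibly the OOM this PR targets; 'multi_process' remains available for the old behaviour.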

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ def encode_single_device(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         # encode
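
This one-line change is the heart of the fix: in the PyTorch versions this code commonly runs under, the CUDA OOM exception is exposed only as torch.cuda.OutOfMemoryError, so evaluating except torch.OutOfMemoryError raised AttributeError while the OOM was being handled (the device-agnostic torch.OutOfMemoryError alias only exists in recent releases). The surrounding logic, identical in all nine files, is a shrink-and-retry probe; a self-contained sketch (probe_batch_size is our name, not the library's):

import torch

def probe_batch_size(encode_fn, inputs, batch_size: int = 256) -> int:
    # Try one batch; on CUDA OOM, shrink the batch size to 3/4 and
    # retry, exactly as the patched encode loops do.
    while batch_size > 0:
        try:
            encode_fn(inputs[:batch_size])
            return batch_size
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            batch_size = batch_size * 3 // 4
    raise RuntimeError("even a single-item batch does not fit on the GPU")

Starting from the default 256, the probe sequence is 256 → 192 → 144 → 108 → 81 → … Note that torch.cuda.OutOfMemoryError subclasses RuntimeError, so the preceding except RuntimeError clause in these files already catches it; the dedicated clause is defensive.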

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 2 additions & 2 deletions
@@ -409,7 +409,7 @@ def encode_queries_single_device(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         # encode
@@ -519,7 +519,7 @@ def encode_single_device(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         # encode

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 33 additions & 1 deletion
@@ -196,6 +196,38 @@ def encode_single_device(
         if self.use_fp16: self.model.half()
 
         self.model.to(device)
+
+        return self.encode_only(
+            sentences=sentences,
+            batch_size=batch_size,
+            max_length=max_length,
+            convert_to_numpy=convert_to_numpy,
+            device=device,
+            **kwargs
+        )
+
+    @torch.no_grad()
+    def encode_only(
+        self,
+        sentences: Union[List[str], str],
+        batch_size: int = 256,
+        max_length: int = 512,
+        convert_to_numpy: bool = True,
+        device: Any = None,
+        **kwargs: Any
+    ):
+        """Encode input sentences.
+
+        Args:
+            sentences (Union[List[str], str]): Input sentences to encode.
+            batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
+            max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
+            convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
+                be a Torch Tensor. Defaults to :data:`True`.
+
+        Returns:
+            Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
+        """
         self.model.eval()
 
         input_was_string = False
@@ -238,7 +270,7 @@ def encode_single_device(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         # encode
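
End to end, the new option is chosen at construction time. A hedged usage sketch — the model name and the devices keyword follow FlagEmbedding's usual FlagModel API, but we have not verified every keyword against this exact revision:

from FlagEmbedding import FlagModel

# multi_GPU_type='dp' keeps inference in one process and shards each
# batch via DataParallel; 'multi_process' (the previous behaviour)
# spawns one worker per device.
model = FlagModel(
    "BAAI/bge-base-en-v1.5",
    devices=["cuda:0", "cuda:1"],  # assumed constructor argument
    multi_GPU_type="dp",
)

embeddings = model.encode(["query one", "query two"], batch_size=256)
print(embeddings.shape)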

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         # encode

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
@@ -412,7 +412,7 @@ def compute_score_single_gpu(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/layerwise.py

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ def compute_score_single_gpu(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/lightweight.py

Lines changed: 1 addition & 1 deletion
@@ -368,7 +368,7 @@ def compute_score_single_gpu(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         all_scores = []

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def compute_score_single_gpu(
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                 batch_size = batch_size * 3 // 4
 
         all_scores = []
