Skip to content

Commit 62b6a1d

Browse files
committed
fix bug: OOM — catch `torch.cuda.OutOfMemoryError` instead of `torch.OutOfMemoryError` in the batch-size back-off handlers (the top-level `torch.OutOfMemoryError` name is presumably unavailable in older torch versions; the `torch.cuda` name is the compatible one — verify against the project's minimum supported torch)
1 parent f18806c commit 62b6a1d

File tree

8 files changed

+9
-9
lines changed

8 files changed

+9
-9
lines changed

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def encode_single_device(
257257
flag = True
258258
except RuntimeError as e:
259259
batch_size = batch_size * 3 // 4
260-
except torch.OutOfMemoryError as e:
260+
except torch.cuda.OutOfMemoryError as e:
261261
batch_size = batch_size * 3 // 4
262262

263263
# encode

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def encode_queries_single_device(
409409
flag = True
410410
except RuntimeError as e:
411411
batch_size = batch_size * 3 // 4
412-
except torch.OutOfMemoryError as e:
412+
except torch.cuda.OutOfMemoryError as e:
413413
batch_size = batch_size * 3 // 4
414414

415415
# encode
@@ -519,7 +519,7 @@ def encode_single_device(
519519
flag = True
520520
except RuntimeError as e:
521521
batch_size = batch_size * 3 // 4
522-
except torch.OutOfMemoryError as e:
522+
except torch.cuda.OutOfMemoryError as e:
523523
batch_size = batch_size * 3 // 4
524524

525525
# encode

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def encode_only(
270270
flag = True
271271
except RuntimeError as e:
272272
batch_size = batch_size * 3 // 4
273-
except torch.OutOfMemoryError as e:
273+
except torch.cuda.OutOfMemoryError as e:
274274
batch_size = batch_size * 3 // 4
275275

276276
# encode

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
406406
flag = True
407407
except RuntimeError as e:
408408
batch_size = batch_size * 3 // 4
409-
except torch.OutOfMemoryError as e:
409+
except torch.cuda.OutOfMemoryError as e:
410410
batch_size = batch_size * 3 // 4
411411

412412
# encode

FlagEmbedding/inference/reranker/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def compute_score_single_gpu(
412412
flag = True
413413
except RuntimeError as e:
414414
batch_size = batch_size * 3 // 4
415-
except torch.OutOfMemoryError as e:
415+
except torch.cuda.OutOfMemoryError as e:
416416
batch_size = batch_size * 3 // 4
417417

418418
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/layerwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def compute_score_single_gpu(
282282
flag = True
283283
except RuntimeError as e:
284284
batch_size = batch_size * 3 // 4
285-
except torch.OutOfMemoryError as e:
285+
except torch.cuda.OutOfMemoryError as e:
286286
batch_size = batch_size * 3 // 4
287287

288288
dataset, dataloader = None, None

FlagEmbedding/inference/reranker/decoder_only/lightweight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def compute_score_single_gpu(
368368
flag = True
369369
except RuntimeError as e:
370370
batch_size = batch_size * 3 // 4
371-
except torch.OutOfMemoryError as e:
371+
except torch.cuda.OutOfMemoryError as e:
372372
batch_size = batch_size * 3 // 4
373373

374374
all_scores = []

FlagEmbedding/inference/reranker/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def compute_score_single_gpu(
169169
flag = True
170170
except RuntimeError as e:
171171
batch_size = batch_size * 3 // 4
172-
except torch.OutOfMemoryError as e:
172+
except torch.cuda.OutOfMemoryError as e:
173173
batch_size = batch_size * 3 // 4
174174

175175
all_scores = []

0 commit comments

Comments (0)