48 changes: 47 additions & 1 deletion FlagEmbedding/abc/inference/AbsEmbedder.py
@@ -36,6 +36,7 @@ class AbsEmbedder(ABC):
        passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
        convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
            Defaults to :data:`True`.
        multi_GPU_type (str, optional): Multi-GPU inference strategy, one of :data:`'dp'` (torch.nn.DataParallel in the main
            process) or :data:`'multi_process'` (one worker process per device). Defaults to :data:`"dp"`.
        kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
    """

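For orientation, a minimal usage sketch of the new option, assuming a concrete embedder such as FlagModel forwards its keyword arguments down to AbsEmbedder (the checkpoint name and device list are illustrative):

from FlagEmbedding import FlagModel

# hypothetical usage sketch, not verbatim library docs
model = FlagModel(
    "BAAI/bge-base-en-v1.5",    # illustrative checkpoint
    use_fp16=True,
    devices=["cuda:0", "cuda:1"],
    multi_GPU_type="dp",        # or "multi_process" for the worker-pool path
)
embeddings = model.encode(["first sentence", "second sentence"])

With 'dp', everything stays in one process and torch.nn.DataParallel shards each batch; 'multi_process' keeps the existing one-worker-per-device pool.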
@@ -52,6 +53,7 @@ def __init__(
        query_max_length: int = 512,
        passage_max_length: int = 512,
        convert_to_numpy: bool = True,
        multi_GPU_type: str = 'dp',
        **kwargs: Any,
    ):
        query_instruction_format = query_instruction_format.replace('\\n', '\n')
@@ -66,6 +68,8 @@
        self.query_max_length = query_max_length
        self.passage_max_length = passage_max_length
        self.convert_to_numpy = convert_to_numpy
        self._multi_GPU_type = multi_GPU_type
        self._dp_set = False

        for k in kwargs:
            setattr(self, k, kwargs[k])
@@ -77,6 +81,21 @@ def __init__(
        self.model = None
        self.pool = None

    def start_dp(self):
        # Wrap the model in DataParallel once, when 'dp' mode targets multiple CUDA devices.
        if self._multi_GPU_type == 'dp' and \
            isinstance(self.target_devices, list) and len(self.target_devices) > 1 and \
            (isinstance(self.target_devices[0], int) or 'cuda' in self.target_devices[0]) and \
            not self._dp_set:

            if self.use_fp16: self.model.half()
            self.model = self.model.to(torch.device("cuda"))
            if isinstance(self.target_devices[0], int):
                self.model = torch.nn.DataParallel(self.model, device_ids=self.target_devices)
            else:
                # 'cuda:N' strings -> integer device ids for DataParallel
                devices = [int(e.split(':')[-1].strip()) for e in self.target_devices]
                self.model = torch.nn.DataParallel(self.model, device_ids=devices)
            self._dp_set = True

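For reference, the core of the wrap above as a standalone sketch (module and device ids are illustrative). Each forward pass, DataParallel replicates the module, scatters the input batch along dim 0 across device_ids, and gathers the outputs on the first device:

import torch

net = torch.nn.Linear(8, 4).half().to("cuda")        # stand-in for the embedding model
net = torch.nn.DataParallel(net, device_ids=[0, 1])  # batch dim 0 is split across GPUs
out = net(torch.randn(32, 8, dtype=torch.half, device="cuda"))  # gathered on cuda:0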
    def stop_self_pool(self):
        if self.pool is not None:
            self.stop_multi_process_pool(self.pool)
@@ -107,7 +126,10 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
            elif is_torch_npu_available():
                return [f"npu:{i}" for i in range(torch.npu.device_count())]
            elif torch.backends.mps.is_available():
-                return [f"mps:{i}" for i in range(torch.mps.device_count())]
+                try:
+                    return [f"mps:{i}" for i in range(torch.mps.device_count())]
+                except Exception:  # torch.mps.device_count is missing in older PyTorch
+                    return ["mps"]
            else:
                return ["cpu"]
        elif isinstance(devices, str):
@@ -253,6 +275,15 @@ def encode(
                device=self.target_devices[0],
                **kwargs
            )

        # 'dp': encode in the main process; DataParallel shards each batch across the target GPUs.
        if self._multi_GPU_type == 'dp':
            return self.encode_only(
                sentences,
                batch_size=batch_size,
                max_length=max_length,
                convert_to_numpy=convert_to_numpy,
                **kwargs
            )

        if self.pool is None:
            self.pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
@@ -262,6 +293,7 @@
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            device=torch.device("cuda"),
            **kwargs
        )
        return embeddings
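Condensed, encode now routes between three paths; a rough illustrative mirror of the logic above, not verbatim library code:

# illustrative summary of AbsEmbedder.encode's routing
def route(emb, sentences, **kw):
    if len(emb.target_devices) == 1:
        return emb.encode_single_device(sentences, device=emb.target_devices[0], **kw)
    if emb._multi_GPU_type == 'dp':
        return emb.encode_only(sentences, **kw)              # DataParallel, main process
    pool = emb.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
    return emb.encode_multi_process(sentences, pool, **kw)   # one worker per device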
@@ -284,6 +316,20 @@ def encode_single_device(
        """
        pass

    def encode_only(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Any = None,
        **kwargs: Any,
    ):
        """
        Encode sentences and return embeddings, assuming the model is already on its target device(s);
        concrete subclasses implement this.
        """
        pass

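A concrete subclass would override it with the same signature; a minimal illustrative skeleton (other abstract members omitted):

class MyEmbedder(AbsEmbedder):
    def encode_only(self, sentences, batch_size=256, max_length=512,
                    convert_to_numpy=True, device=None, **kwargs):
        # tokenize -> forward -> pool -> (optionally) convert to numpy
        raise NotImplementedError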
    # adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
    def start_multi_process_pool(
        self,
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/embedder/decoder_only/base.py
@@ -257,7 +257,7 @@ def encode_single_device(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        # encode
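This and the remaining except-clause hunks below all patch the same batch-size backoff loop. In older PyTorch releases torch.OutOfMemoryError does not exist while torch.cuda.OutOfMemoryError does, hence the rename. A standalone sketch of the pattern (the probe call is illustrative); note torch.cuda.OutOfMemoryError subclasses RuntimeError, so it must be caught first to be distinguishable:

import torch

def shrink_until_fits(probe, batch_size=256):
    # Back off by 25% until one forward pass fits in GPU memory (sketch).
    while batch_size > 0:
        try:
            probe(batch_size)                   # illustrative: tokenize + encode one batch
            return batch_size
        except torch.cuda.OutOfMemoryError:     # subclass of RuntimeError, catch it first
            batch_size = batch_size * 3 // 4
        except RuntimeError:
            batch_size = batch_size * 3 // 4
    raise RuntimeError("no batch size fits in memory")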
4 changes: 2 additions & 2 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -409,7 +409,7 @@ def encode_queries_single_device(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        # encode
@@ -519,7 +519,7 @@ def encode_single_device(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        # encode
34 changes: 33 additions & 1 deletion FlagEmbedding/inference/embedder/encoder_only/base.py
@@ -196,6 +196,38 @@ def encode_single_device(
        if self.use_fp16: self.model.half()

        self.model.to(device)

        return self.encode_only(
            sentences=sentences,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            device=device,
            **kwargs
        )

    @torch.no_grad()
    def encode_only(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Any = None,
        **kwargs: Any
    ):
"""Encode input sentences.

Args:
sentences (Union[List[str], str]): Input sentences to encode.
batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
be a Torch Tensor. Defaults to :data:`True`.

Returns:
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
"""
        self.model.eval()

        input_was_string = False
@@ -238,7 +270,7 @@ def encode_single_device(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        # encode
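Since encode_only now carries the real encoding loop, the DataParallel path boils down to roughly the following; a sketch assuming `embedder` is a BaseEmbedder instance whose start_dp() call has wrapped the model:

embedder.start_dp()                  # wraps self.model in torch.nn.DataParallel once
vecs = embedder.encode_only(
    ["a query", "another query"],
    batch_size=512,                  # DataParallel splits each batch across the GPUs
    max_length=512,
    convert_to_numpy=True,
)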
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -406,7 +406,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        # encode
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/reranker/decoder_only/base.py
@@ -412,7 +412,7 @@ def compute_score_single_gpu(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        dataset, dataloader = None, None
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/reranker/decoder_only/layerwise.py
@@ -282,7 +282,7 @@ def compute_score_single_gpu(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        dataset, dataloader = None, None
@@ -368,7 +368,7 @@ def compute_score_single_gpu(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        all_scores = []
2 changes: 1 addition & 1 deletion FlagEmbedding/inference/reranker/encoder_only/base.py
@@ -169,7 +169,7 @@ def compute_score_single_gpu(
                flag = True
            except RuntimeError as e:
                batch_size = batch_size * 3 // 4
-            except torch.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError as e:
                batch_size = batch_size * 3 // 4

        all_scores = []