
Commit da2e8da

Michel Bartels (MichelBartels) and tholor authored
Adding multi gpu support for DPR inference (#1414)
* Added support for multi-GPU inference to DPR, including benchmark
* Fixed multi-GPU handling
* Added batch size to benchmark to better reflect multi-GPU capabilities
* Removed unnecessary entry in config.json
* Fixed typos
* Fixed config name
* Updated benchmark to use DEVICES constant
* Changed multi-GPU parameters and updated docstring
* Added silent fallback on CPU
* Updated docstring, warning and config

Co-authored-by: Michel Bartels <[email protected]>
Co-authored-by: Malte Pietsch <[email protected]>
1 parent 1f85969 commit da2e8da
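For orientation, a minimal usage sketch of the new devices parameter (hedged: the model names are taken from the benchmark utils changed below, and doc_store stands in for any DocumentStore initialized elsewhere):

from haystack.retriever.dense import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=doc_store,  # assumption: any initialized DocumentStore
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    devices=["cuda:0", "cuda:1"],  # omit to use all available GPUs, with silent CPU fallback
    batch_size=32,                 # total batch size across all listed devices
)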

File tree

3 files changed: +35 −15 lines


haystack/retriever/dense.py

Lines changed: 27 additions & 10 deletions
@@ -2,6 +2,7 @@
 from abc import abstractmethod
 from typing import List, Union, Optional
 import torch
+from torch.nn import DataParallel
 import numpy as np
 from pathlib import Path

@@ -52,7 +53,8 @@ def __init__(self,
                  infer_tokenizer_classes: bool = False,
                  similarity_function: str = "dot_product",
                  global_loss_buffer_size: int = 150000,
-                 progress_bar: bool = True
+                 progress_bar: bool = True,
+                 devices: Optional[List[Union[int, str, torch.device]]] = None
                  ):
         """
         Init the Retriever incl. the two encoder models from a local or remote model checkpoint.

@@ -82,8 +84,8 @@ def __init__(self,
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
         :param top_k: How many documents to return per query.
-        :param use_gpu: Whether to use gpu or not
-        :param batch_size: Number of questions or passages to encode at once
+        :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+        :param batch_size: Number of questions or passages to encode at once. In case of multiple gpus, this will be the total batch size.
         :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
                             This is the approach used in the original paper and is likely to improve performance if your
                             titles contain meaningful information for retrieval (topic, entities etc.) .

@@ -99,6 +101,8 @@ def __init__(self,
                                         Increase if errors like "encoded data exceeds max_size ..." come up
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
+        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
         """

         # save init parameters to enable export of component config as YAML
@@ -108,9 +112,19 @@ def __init__(self,
             model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
             top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
-            similarity_function=similarity_function, progress_bar=progress_bar,
+            similarity_function=similarity_function, progress_bar=progress_bar, devices=devices
         )

+        if devices is not None:
+            self.devices = devices
+        elif use_gpu and torch.cuda.is_available():
+            self.devices = [torch.device(device) for device in range(torch.cuda.device_count())]
+        else:
+            self.devices = [torch.device("cpu")]
+
+        if batch_size < len(self.devices):
+            logger.warning("Batch size is less than the number of devices. All gpus will not be utilized.")
+
         self.document_store = document_store
         self.batch_size = batch_size
         self.progress_bar = progress_bar

@@ -125,8 +139,6 @@ def __init__(self,
                            "We recommend you use dot_product instead. "
                            "This can be set when initializing the DocumentStore")

-        self.device, _ = initialize_device_settings(use_cuda=use_gpu)
-
         self.infer_tokenizer_classes = infer_tokenizer_classes
         tokenizers_default_classes = {
             "query": "DPRQuestionEncoderTokenizer",

@@ -171,11 +183,14 @@ def __init__(self,
                                          embeds_dropout_prob=0.1,
                                          lm1_output_types=["per_sequence"],
                                          lm2_output_types=["per_sequence"],
-                                         device=self.device,
+                                         device=self.devices[0],
                                          )

         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

+        if len(self.devices) > 1:
+            self.model = DataParallel(self.model, device_ids=self.devices)
+
     def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
         """
         Scan through documents in DocumentStore and return a small number documents

@@ -234,7 +249,7 @@ def _get_predictions(self, dicts):
         with tqdm(total=len(data_loader)*self.batch_size, unit=" Docs", desc=f"Create embeddings", position=1,
                   leave=False, disable=disable_tqdm) as progress_bar:
             for batch in data_loader:
-                batch = {key: batch[key].to(self.device) for key in batch}
+                batch = {key: batch[key].to(self.devices[0]) for key in batch}

                 # get logits
                 with torch.no_grad():

@@ -371,7 +386,7 @@ def train(self,
             n_batches=len(data_silo.loaders["train"]),
             n_epochs=n_epochs,
             grad_acc_steps=grad_acc_steps,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-gpu training is not implemented
             use_amp=use_amp
         )

@@ -384,7 +399,7 @@ def train(self,
             n_gpu=n_gpu,
             lr_schedule=lr_schedule,
             evaluate_every=evaluate_every,
-            device=self.device,
+            device=self.devices[0],  # Only use first device while multi-gpu training is not implemented
             use_amp=use_amp
         )

@@ -395,6 +410,8 @@ def train(self,
         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")

+        self.model = DataParallel(self.model, device_ids=self.devices)
+
     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
              passage_encoder_dir: str = "passage_encoder"):
         """

test/benchmarks/retriever.py

Lines changed: 4 additions & 2 deletions
@@ -35,6 +35,8 @@
 map_json = "../../docs/_src/benchmarks/retriever_map.json"
 speed_json = "../../docs/_src/benchmarks/retriever_speed.json"

+DEVICES = None
+

 seed = 42
 random.seed(42)

@@ -47,7 +49,7 @@ def benchmark_indexing(n_docs_options, retriever_doc_stores, data_dir, filename_
             logger.info(f"##### Start indexing run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
             try:
                 doc_store = get_document_store(doc_store_name)
-                retriever = get_retriever(retriever_name, doc_store)
+                retriever = get_retriever(retriever_name, doc_store, DEVICES)
                 docs, _ = prepare_data(data_dir=data_dir,
                                        filename_gold=filename_gold,
                                        filename_negative=filename_negative,

@@ -143,7 +145,7 @@ def benchmark_querying(n_docs_options,
             else:
                 similarity = "dot_product"
             doc_store = get_document_store(doc_store_name, similarity=similarity)
-            retriever = get_retriever(retriever_name, doc_store)
+            retriever = get_retriever(retriever_name, doc_store, DEVICES)
             add_precomputed = retriever_name in ["dpr"]
             # For DPR, precomputed embeddings are loaded from file
             docs, labels = prepare_data(data_dir=data_dir,
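DEVICES = None keeps the default behavior (all available GPUs). To pin the benchmark to specific cards, one would presumably override the constant, e.g.:

import torch

# Hypothetical override of the DEVICES constant added above:
DEVICES = [torch.device("cuda:0"), torch.device("cuda:1")]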

test/benchmarks/utils.py

Lines changed: 4 additions & 3 deletions
@@ -94,7 +94,7 @@ def get_document_store(document_store_type, similarity='dot_product', index="doc
         raise Exception(f"No document store fixture for '{document_store_type}'")
     return document_store

-def get_retriever(retriever_name, doc_store):
+def get_retriever(retriever_name, doc_store, devices):
     if retriever_name == "elastic":
         return ElasticsearchRetriever(doc_store)
     if retriever_name == "tfidf":

@@ -104,7 +104,8 @@ def get_retriever(retriever_name, doc_store):
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                      use_gpu=True,
-                                     use_fast_tokenizers=False)
+                                     use_fast_tokenizers=False,
+                                     devices=devices)
     if retriever_name == "sentence_transformers":
         return EmbeddingRetriever(document_store=doc_store,
                                   embedding_model="nq-distilbert-base-v1",

@@ -166,4 +167,4 @@ def download_from_url(url: str, filepath:Union[str, Path]):
         logger.info(f"Downloading {url} to {filepath} ")
         with open(filepath, "wb") as file:
             http_get(url=url, temp_file=file)
-    return filepath
+    return filepath
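Tying the benchmark pieces together, a hedged example of calling the updated helper ("elasticsearch" as a document store type name is an assumption; the devices argument is passed straight through to DensePassageRetriever):

doc_store = get_document_store("elasticsearch")          # assumed store type name
retriever = get_retriever("dpr", doc_store, ["cuda:0"])  # or None to use all GPUs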
