Adding multi gpu support for DPR inference (#1414)
* Added support for Multi-GPU inference to DPR including benchmark
* fixed multi gpu
* added batch size to benchmark to better reflect multi gpu capabilities
* remove unnecessary entry in config.json
* fixed typos
* fixed config name
* update benchmark to use DEVICES constant
* changed multi gpu parameters and updated docstring
* adds silent fallback on cpu
* update doc string, warning and config
Co-authored-by: Michel Bartels <[email protected]>
Co-authored-by: Malte Pietsch <[email protected]>
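The "silent fallback on cpu" mentioned in the commit list can be sketched as follows. This is an illustrative sketch only: `resolve_devices` and its parameters are hypothetical names, not Haystack's actual internals.

```python
# Hypothetical sketch of the "silent fallback on cpu" behavior described in
# the commit list; resolve_devices is an illustrative name, not Haystack's API.
def resolve_devices(use_gpu, available_gpus, requested=None):
    """Return the device list to run inference on.

    use_gpu:        whether the caller asked for GPU inference
    available_gpus: GPUs detected on the machine, e.g. ["cuda:0", "cuda:1"]
    requested:      optional subset of GPUs to limit inference to
    """
    if use_gpu and available_gpus:
        # Use the requested subset if given, otherwise all available GPUs.
        return requested or available_gpus
    # Silently fall back on CPU when no GPU is available or use_gpu is False.
    return ["cpu"]

print(resolve_devices(True, [], requested=["cuda:0"]))  # ['cpu'] (fallback)
print(resolve_devices(True, ["cuda:0", "cuda:1"]))      # ['cuda:0', 'cuda:1']
```

In the real library the fallback is "silent" in the sense that no exception is raised when GPUs are requested but unavailable; the later commit adds a warning alongside it.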
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -82,8 +84,8 @@ def __init__(self,
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
         :param top_k: How many documents to return per query.
-        :param use_gpu: Whether to use gpu or not
-        :param batch_size: Number of questions or passages to encode at once
+        :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
+        :param batch_size: Number of questions or passages to encode at once. In case of multiple gpus, this will be the total batch size.
         :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
                             This is the approach used in the original paper and is likely to improve performance if your
                             titles contain meaningful information for retrieval (topic, entities etc.) .
@@ -99,6 +101,8 @@ def __init__(self,
                             Increase if errors like "encoded data exceeds max_size ..." come up
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
+        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
         """

        # save init parameters to enable export of component config as YAML
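The updated `batch_size` docstring states that with multiple GPUs the value is the *total* batch size. A minimal sketch of one way such a total could be split across devices (illustrative only; `per_device_batch_size` is a hypothetical helper, not the library's implementation):

```python
# Hypothetical sketch (not Haystack's actual code): splitting a total batch
# size across multiple devices, matching the docstring's "total batch size"
# semantics for multi-GPU inference.
def per_device_batch_size(total_batch_size, devices):
    # Each device receives an equal share; any remainder is spread over the
    # first devices in the list.
    n = len(devices)
    base, rem = divmod(total_batch_size, n)
    return {d: base + (1 if i < rem else 0) for i, d in enumerate(devices)}

print(per_device_batch_size(16, ["cuda:0", "cuda:1"]))  # {'cuda:0': 8, 'cuda:1': 8}
```

With this reading, passing `batch_size=16` on two GPUs encodes 8 questions or passages per device at a time, rather than 16 per device.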