
Commit 6cb4e93

refactor: remove Inferencer multiprocessing (#3283)
1 parent b49bce9 commit 6cb4e93

File tree: 2 files changed, +47 -154 lines

  haystack/modeling/infer.py
  test/conftest.py

haystack/modeling/infer.py

Lines changed: 38 additions & 138 deletions
@@ -1,9 +1,7 @@
-from typing import List, Optional, Dict, Union, Generator, Set, Any
+from typing import List, Optional, Dict, Union, Set, Any
 
 import os
 import logging
-import multiprocessing as mp
-from functools import partial
 from tqdm import tqdm
 import torch
 from torch.utils.data.sampler import SequentialSampler
@@ -12,13 +10,7 @@
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import (
-    grouper,
-    initialize_device_settings,
-    set_all_seeds,
-    calc_chunksize,
-    log_ascii_workers,
-)
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
@@ -70,6 +62,9 @@ def __init__(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
                               done using this class. The garbage collector will not do this for you!
+                              .. deprecated:: 1.10
+                                    This parameter has no effect; it will be removed as Inferencer multiprocessing
+                                    has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
                         A list containing torch device objects and/or strings is supported (For example
@@ -113,8 +108,6 @@ def __init__(
         model.connect_heads_with_processor(processor.tasks, require_labels=False)
         set_all_seeds(42)
 
-        self._set_multiprocessing_pool(num_processes)
-
     @classmethod
     def load(
         cls,
@@ -166,6 +159,9 @@ def load(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
                               done using this class. The garbage collector will not do this for you!
+                              .. deprecated:: 1.10
+                                    This parameter has no effect; it will be removed as Inferencer multiprocessing
+                                    has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
         :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
@@ -259,48 +255,6 @@ def load(
             devices=devices,
         )
 
-    def _set_multiprocessing_pool(self, num_processes: Optional[int]) -> None:
-        """
-        Initialize a multiprocessing.Pool for instances of Inferencer.
-
-        :param num_processes: the number of processes for `multiprocessing.Pool`.
-                              Set to value of 1 (or 0) to disable multiprocessing.
-                              Set to None to let Inferencer use all CPU cores minus one.
-                              If you want to debug the Language Model, you might need to disable multiprocessing!
-                              **Warning!** If you use multiprocessing you have to close the
-                              `multiprocessing.Pool` again! To do so call
-                              :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
-                              done using this class. The garbage collector will not do this for you!
-        :return: None
-        """
-        self.process_pool = None
-        if num_processes == 0 or num_processes == 1: # disable multiprocessing
-            self.process_pool = None
-        else:
-            if num_processes is None: # use all CPU cores
-                if mp.cpu_count() > 3:
-                    num_processes = mp.cpu_count() - 1
-                else:
-                    num_processes = mp.cpu_count()
-            self.process_pool = mp.Pool(processes=num_processes)
-            logger.info("Got ya %s parallel workers to do inference ...", num_processes)
-            log_ascii_workers(n=num_processes, logger=logger)
-
-    def close_multiprocessing_pool(self, join: bool = False):
-        """Close the `multiprocessing.Pool` again.
-
-        If you use multiprocessing you have to close the `multiprocessing.Pool` again!
-        To do so call this function after you are done using this class.
-        The garbage collector will not do this for you!
-
-        :param join: wait for the worker processes to exit
-        """
-        if self.process_pool is not None:
-            self.process_pool.close()
-            if join:
-                self.process_pool.join()
-            self.process_pool = None
-
     def save(self, path: str):
         self.model.save(path)
         self.processor.save(path)
@@ -313,6 +267,9 @@ def inference_from_file(self, file: str, multiprocessing_chunksize: int = None,
 
         :param file: path of the input file for Inference
         :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         dicts = self.processor.file_to_dicts(file)
@@ -333,8 +290,11 @@ def inference_from_dicts(
                       One dict per sample.
         :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
                             object where applicable, else it returns PredObj.to_json()
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
                                           (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         # whether to aggregate predictions across different samples (e.g. for QA on long texts)
@@ -346,26 +306,8 @@ def inference_from_dicts(
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")
 
-        if self.process_pool is None: # multiprocessing disabled (helpful for debugging or using in web frameworks)
-            predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
-            return predictions
-        else: # use multiprocessing for inference
-            # Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
-
-            if multiprocessing_chunksize is None:
-                _chunk_size, _ = calc_chunksize(len(dicts))
-                multiprocessing_chunksize = _chunk_size
-
-            predictions = self._inference_with_multiprocessing(
-                dicts, return_json, aggregate_preds, multiprocessing_chunksize
-            )
-
-            self.processor.log_problematic(self.problematic_sample_ids)
-            # cast the generator to a list if it isnt already a list.
-            if type(predictions) != list:
-                return list(predictions)
-            else:
-                return predictions
+        predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
+        return predictions
 
     def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
         """
@@ -399,69 +341,6 @@ def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: boo
 
         return preds_all
 
-    def _inference_with_multiprocessing(
-        self,
-        dicts: Union[List[Dict], Generator[Dict, None, None]],
-        return_json: bool,
-        aggregate_preds: bool,
-        multiprocessing_chunksize: int,
-    ) -> Generator[Dict, None, None]:
-        """
-        Implementation of inference. This method is a generator that yields the results.
-
-        :param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
-        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
-                            object where applicable, else it returns PredObj.to_json()
-        :param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
-        :return: generator object that yield predictions
-        """
-
-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        if self.process_pool is not None:
-            results = self.process_pool.imap(
-                partial(self._create_datasets_chunkwise, processor=self.processor),
-                grouper(iterable=dicts, n=multiprocessing_chunksize),
-                1,
-            )
-
-            # Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
-            # So we don't need to wait until all preprocessing has finished before getting first predictions.
-            for dataset, tensor_names, problematic_sample_ids, baskets in results:
-                self.problematic_sample_ids.update(problematic_sample_ids)
-                if dataset is None:
-                    logger.error(
-                        f"Part of the dataset could not be converted! \n"
-                        f"BE AWARE: The order of predictions will not conform with the input order!"
-                    )
-                else:
-                    # TODO change format of formatted_preds in QA (list of dicts)
-                    if aggregate_preds:
-                        predictions = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
-                    else:
-                        predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-                    if return_json:
-                        # TODO this try catch should be removed when all tasks return prediction objects
-                        try:
-                            predictions = [x.to_json() for x in predictions]
-                        except AttributeError:
-                            pass
-                    yield from predictions
-
-    @classmethod
-    def _create_datasets_chunkwise(cls, chunk, processor: Processor):
-        """Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
-        This is usually executed in one of many parallel processes.
-        The resulting datasets of the processes are merged together afterwards"""
-        dicts = [d[1] for d in chunk]
-        indices = [d[0] for d in chunk]
-        dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(
-            dicts, indices, return_baskets=True
-        )
-        return dataset, tensor_names, problematic_sample_ids, baskets
-
     def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -592,20 +471,41 @@ def __init__(self, *args, **kwargs):
     def inference_from_dicts(
         self, dicts: List[dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_dicts(
             self, dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
 
     def inference_from_file(
         self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json=True
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_file(
             self, file, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
 
     def inference_from_objects(
         self, objects: List[QAInput], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!
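
Note: after this change the Inferencer always runs preprocessing and prediction in the calling process; num_processes and multiprocessing_chunksize are still accepted but have no effect. Below is a minimal usage sketch that mirrors the model name and arguments used in the test fixture further down. The exact input-dict schema is defined by the task's Processor, so the QA dict shown here is illustrative only.

from haystack.modeling.infer import Inferencer

# num_processes no longer creates a multiprocessing.Pool; inference is single-process.
inferencer = Inferencer.load(
    "deepset/bert-base-cased-squad2",
    task_type="question_answering",
    batch_size=16,
    gpu=False,
)

# Illustrative QA input; the schema expected here is defined by the task's Processor.
qa_dicts = [{"qas": ["Who is the father of Arya Stark?"],
             "context": "Arya Stark is the daughter of Eddard Stark."}]
predictions = inferencer.inference_from_dicts(dicts=qa_dicts, return_json=True)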

test/conftest.py

Lines changed: 9 additions & 16 deletions
@@ -1115,22 +1115,15 @@ def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
     """
-    try:
-        model = Inferencer.load(
-            "deepset/bert-base-cased-squad2",
-            task_type="question_answering",
-            batch_size=16,
-            num_processes=num_processes,
-            gpu=False,
-        )
-        yield model
-    finally:
-        if num_processes != 0:
-            # close the pool
-            # we pass join=True to wait for all sub processes to close
-            # this is because below we want to test if all sub-processes
-            # have exited
-            model.close_multiprocessing_pool(join=True)
+
+    model = Inferencer.load(
+        "deepset/bert-base-cased-squad2",
+        task_type="question_answering",
+        batch_size=16,
+        num_processes=num_processes,
+        gpu=False,
+    )
+    yield model
 
     # check if all workers (sub processes) are closed
     current_process = psutil.Process()
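
Note: the trailing context lines show the start of a check that the fixture still performs after yielding; the rest of that check lies outside this hunk. A minimal sketch of how such a check could look with psutil follows. The assertion is an assumption for illustration, not the repository's actual code.

import psutil

# Inspect the test process after the Inferencer fixture has been torn down.
current_process = psutil.Process()
children = current_process.children()  # direct child processes of the test process

# With the multiprocessing pool removed, no inference worker processes should remain.
assert len(children) == 0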
