-from typing import List, Optional, Dict, Union, Generator, Set, Any
+from typing import List, Optional, Dict, Union, Set, Any

 import os
 import logging
-import multiprocessing as mp
-from functools import partial
 from tqdm import tqdm
 import torch
 from torch.utils.data.sampler import SequentialSampler
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import (
-    grouper,
-    initialize_device_settings,
-    set_all_seeds,
-    calc_chunksize,
-    log_ascii_workers,
-)
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
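The two helpers that survive the import cleanup are the seeding and device utilities. A minimal sketch of how they are typically used together; the exact `initialize_device_settings` signature (assumed here to accept `use_cuda` and return a `(devices, n_gpu)` tuple) should be checked against `haystack.modeling.utils`:

```python
from haystack.modeling.utils import initialize_device_settings, set_all_seeds

set_all_seeds(42)  # fix all RNG seeds so inference runs are reproducible
# Resolve which torch devices to run on (assumed signature and return shape).
devices, n_gpu = initialize_device_settings(use_cuda=True)
```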
@@ -70,6 +62,9 @@ def __init__(
             `multiprocessing.Pool` again! To do so call
             :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
             done using this class. The garbage collector will not do this for you!
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
             A list containing torch device objects and/or strings is supported (For example
@@ -113,8 +108,6 @@ def __init__(
         model.connect_heads_with_processor(processor.tasks, require_labels=False)
         set_all_seeds(42)

-        self._set_multiprocessing_pool(num_processes)
-
     @classmethod
     def load(
         cls,
@@ -166,6 +159,9 @@ def load(
             `multiprocessing.Pool` again! To do so call
             :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
             done using this class. The garbage collector will not do this for you!
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
         :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
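Since the parameter is now a no-op, existing call sites keep working unchanged. A hedged sketch of a `load()` call; model name and task type are placeholders, not taken from this diff:

```python
inferencer = Inferencer.load(
    "deepset/roberta-base-squad2",  # placeholder: any supported model name or directory
    task_type="question_answering",
    num_processes=0,                # deprecated: accepted for compatibility, has no effect
)
```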
@@ -259,48 +255,6 @@ def load(
             devices=devices,
         )

-    def _set_multiprocessing_pool(self, num_processes: Optional[int]) -> None:
-        """
-        Initialize a multiprocessing.Pool for instances of Inferencer.
-
-        :param num_processes: the number of processes for `multiprocessing.Pool`.
-            Set to value of 1 (or 0) to disable multiprocessing.
-            Set to None to let Inferencer use all CPU cores minus one.
-            If you want to debug the Language Model, you might need to disable multiprocessing!
-            **Warning!** If you use multiprocessing you have to close the
-            `multiprocessing.Pool` again! To do so call
-            :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
-            done using this class. The garbage collector will not do this for you!
-        :return: None
-        """
-        self.process_pool = None
-        if num_processes == 0 or num_processes == 1:  # disable multiprocessing
-            self.process_pool = None
-        else:
-            if num_processes is None:  # use all CPU cores
-                if mp.cpu_count() > 3:
-                    num_processes = mp.cpu_count() - 1
-                else:
-                    num_processes = mp.cpu_count()
-            self.process_pool = mp.Pool(processes=num_processes)
-            logger.info("Got ya %s parallel workers to do inference ...", num_processes)
-            log_ascii_workers(n=num_processes, logger=logger)
-
-    def close_multiprocessing_pool(self, join: bool = False):
-        """Close the `multiprocessing.Pool` again.
-
-        If you use multiprocessing you have to close the `multiprocessing.Pool` again!
-        To do so call this function after you are done using this class.
-        The garbage collector will not do this for you!
-
-        :param join: wait for the worker processes to exit
-        """
-        if self.process_pool is not None:
-            self.process_pool.close()
-            if join:
-                self.process_pool.join()
-            self.process_pool = None
-
     def save(self, path: str):
         self.model.save(path)
         self.processor.save(path)
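`save()` persists model and processor side by side, so the resulting directory can be handed straight back to `load()`. A minimal round-trip sketch; the path is a placeholder:

```python
inferencer.save("saved_models/my-inferencer")             # writes model weights and processor config
reloaded = Inferencer.load("saved_models/my-inferencer")  # restores both for further inference
```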
@@ -313,6 +267,9 @@ def inference_from_file(self, file: str, multiprocessing_chunksize: int = None,

         :param file: path of the input file for Inference
         :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
         :return: list of predictions
         """
         dicts = self.processor.file_to_dicts(file)
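As the first body line shows, the method simply runs the attached processor's `file_to_dicts` and delegates to the dict-based path. A usage sketch; the filename is a placeholder, and the expected format is whatever the processor understands (e.g. SQuAD-style JSON for QA):

```python
preds = inferencer.inference_from_file(file="my_qa_input.json")
for pred in preds:
    print(pred)  # one prediction per input sample (JSON-style dicts by default)
```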
@@ -333,8 +290,11 @@ def inference_from_dicts(
         One dict per sample.
         :param return_json: Whether the output should be in a JSON-appropriate format. If False, it returns the
             prediction object where applicable; otherwise it returns PredObj.to_json()
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
             (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
         :return: list of predictions
         """
         # whether to aggregate predictions across different samples (e.g. for QA on long texts)
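For QA models, the input dicts follow the SQuAD-style shape FARM has used historically. A hedged sketch of a call; the text values are placeholders:

```python
QA_input = [
    {
        "qas": ["Who counted the game among the best ever made?"],
        "context": "Twilight Princess was received as one of the best games of 2006 ...",
    }
]
result = inferencer.inference_from_dicts(dicts=QA_input)  # returns a list of predictions
```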
@@ -346,26 +306,8 @@ def inference_from_dicts(
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")

-        if self.process_pool is None:  # multiprocessing disabled (helpful for debugging or using in web frameworks)
-            predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
-            return predictions
-        else:  # use multiprocessing for inference
-            # Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
-
-            if multiprocessing_chunksize is None:
-                _chunk_size, _ = calc_chunksize(len(dicts))
-                multiprocessing_chunksize = _chunk_size
-
-            predictions = self._inference_with_multiprocessing(
-                dicts, return_json, aggregate_preds, multiprocessing_chunksize
-            )
-
-            self.processor.log_problematic(self.problematic_sample_ids)
-            # cast the generator to a list if it isnt already a list.
-            if type(predictions) != list:
-                return list(predictions)
-            else:
-                return predictions
+        predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
+        return predictions

     def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
         """
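The practical consequence of this hunk is that only one code path remains, so the old mandatory pool cleanup disappears from caller code. A before/after sketch; the model path is a placeholder:

```python
# Before: multiprocessing callers had to release the pool themselves.
#   inferencer = Inferencer.load("my-model-dir", num_processes=4)
#   preds = inferencer.inference_from_dicts(dicts)
#   inferencer.close_multiprocessing_pool()  # mandatory; the GC would not do it
# After: inference always runs in the calling process, nothing to clean up.
inferencer = Inferencer.load("my-model-dir")
preds = inferencer.inference_from_dicts(dicts)
```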
@@ -399,69 +341,6 @@ def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: boo

         return preds_all

-    def _inference_with_multiprocessing(
-        self,
-        dicts: Union[List[Dict], Generator[Dict, None, None]],
-        return_json: bool,
-        aggregate_preds: bool,
-        multiprocessing_chunksize: int,
-    ) -> Generator[Dict, None, None]:
-        """
-        Implementation of inference. This method is a generator that yields the results.
-
-        :param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
-        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
-            object where applicable, else it returns PredObj.to_json()
-        :param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
-        :return: generator object that yield predictions
-        """
-
-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        if self.process_pool is not None:
-            results = self.process_pool.imap(
-                partial(self._create_datasets_chunkwise, processor=self.processor),
-                grouper(iterable=dicts, n=multiprocessing_chunksize),
-                1,
-            )
-
-            # Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
-            # So we don't need to wait until all preprocessing has finished before getting first predictions.
-            for dataset, tensor_names, problematic_sample_ids, baskets in results:
-                self.problematic_sample_ids.update(problematic_sample_ids)
-                if dataset is None:
-                    logger.error(
-                        f"Part of the dataset could not be converted! \n"
-                        f"BE AWARE: The order of predictions will not conform with the input order!"
-                    )
-                else:
-                    # TODO change format of formatted_preds in QA (list of dicts)
-                    if aggregate_preds:
-                        predictions = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
-                    else:
-                        predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-                    if return_json:
-                        # TODO this try catch should be removed when all tasks return prediction objects
-                        try:
-                            predictions = [x.to_json() for x in predictions]
-                        except AttributeError:
-                            pass
-                    yield from predictions
-
-    @classmethod
-    def _create_datasets_chunkwise(cls, chunk, processor: Processor):
-        """Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
-        This is usually executed in one of many parallel processes.
-        The resulting datasets of the processes are merged together afterwards"""
-        dicts = [d[1] for d in chunk]
-        indices = [d[0] for d in chunk]
-        dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(
-            dicts, indices, return_baskets=True
-        )
-        return dataset, tensor_names, problematic_sample_ids, baskets
-
     def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -592,20 +471,41 @@ def __init__(self, *args, **kwargs):
     def inference_from_dicts(
         self, dicts: List[dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
+        """
         return Inferencer.inference_from_dicts(
             self, dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )

     def inference_from_file(
         self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json=True
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
+        """
         return Inferencer.inference_from_file(
             self, file, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )

     def inference_from_objects(
         self, objects: List[QAInput], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect and will be removed, as Inferencer
+                multiprocessing has been deprecated.
+        """
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!
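`inference_from_objects` converts `QAInput` objects to dicts and reuses the dict-based path. A hedged sketch; the `QAInput`/`Question` field names are assumed to match `haystack.modeling.data_handler.inputs` and the text values are placeholders:

```python
from haystack.modeling.data_handler.inputs import QAInput, Question

input_obj = QAInput(
    doc_text="Twilight Princess was received as one of the best games of 2006 ...",
    questions=Question(text="How was Twilight Princess received?"),
)
# Assumes qa_inferencer is a loaded QAInferencer instance.
preds = qa_inferencer.inference_from_objects(objects=[input_obj], return_json=False)
```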