 from medcat.storage.serialisers import serialise, AvailableSerialisers
 from medcat.storage.serialisers import deserialise
 from medcat.storage.serialisables import AbstractSerialisable
+from medcat.storage.mp_ents_save import BatchAnnotationSaver
 from medcat.utils.fileutils import ensure_folder_if_parent
 from medcat.utils.hasher import Hasher
 from medcat.pipeline.pipeline import Pipeline
@@ -159,7 +160,7 @@ def get_entities(self,
     def _mp_worker_func(
         self,
         texts_and_indices: list[tuple[str, str, bool]]
-    ) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
+    ) -> list[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         # NOTE: this is needed for subprocesses as otherwise they wouldn't
         # have any of these set
         # NOTE: these need to be dynamic in case the extras aren't included
@@ -180,7 +181,7 @@ def _mp_worker_func(
         elif has_rel_cat and isinstance(addon, RelCATAddon):
             addon._rel_cat._init_data_paths()
         return [
-            (text, text_index, self.get_entities(text, only_cui=only_cui))
+            (text_index, self.get_entities(text, only_cui=only_cui))
             for text, text_index, only_cui in texts_and_indices]

     def _generate_batches_by_char_length(
@@ -256,7 +257,8 @@ def _mp_one_batch_per_process(
         self,
         executor: ProcessPoolExecutor,
         batch_iter: Iterator[list[tuple[str, str, bool]]],
-        external_processes: int
+        external_processes: int,
+        saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         futures: list[Future] = []
         # submit batches, one for each external process
@@ -269,16 +271,16 @@ def _mp_one_batch_per_process(
                 break
         if not futures:
             # NOTE: if there wasn't any data, we didn't process anything
-            return
+            raise OutOfDataException()
         # Main process works on next batch while workers are busy
         main_batch: Optional[list[tuple[str, str, bool]]]
         try:
             main_batch = next(batch_iter)
             main_results = self._mp_worker_func(main_batch)
-
+            if saver:
+                saver(main_results)
             # Yield main process results immediately
-            for result in main_results:
-                yield result[1], result[2]
+            yield from main_results

         except StopIteration:
             main_batch = None
@@ -295,20 +297,12 @@ def _mp_one_batch_per_process(
             done_future = next(as_completed(futures))
             futures.remove(done_future)

-            # Yield all results from this batch
-            for result in done_future.result():
-                yield result[1], result[2]
+            cur_results = done_future.result()
+            if saver:
+                saver(cur_results)

-            # Submit next batch to keep workers busy
-            try:
-                batch = next(batch_iter)
-                futures.append(
-                    executor.submit(self._mp_worker_func, batch))
-            except StopIteration:
-                # NOTE: if there's nothing to batch, we've got nothing
-                # to submit in terms of new work to the workers,
-                # but we may still have some futures to wait for
-                pass
+            # Yield all results from this batch
+            yield from cur_results

     def get_entities_multi_texts(
         self,
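The scheduling above, in miniature: each round submits at most one batch per external worker, the main process annotates a batch of its own while they run, and exhaustion is signalled to the caller with a sentinel exception. A minimal sketch of the pattern, with illustrative names rather than the library's API (`work` would need to be a picklable, module-level callable):

    from concurrent.futures import ProcessPoolExecutor, as_completed

    class OutOfData(Exception):
        """Sentinel: the batch iterator was already exhausted."""

    def one_round(executor, batch_iter, n_external, work):
        # Submit at most one batch per external worker.
        futures = []
        for _ in range(n_external):
            batch = next(batch_iter, None)
            if batch is None:
                break
            futures.append(executor.submit(work, batch))
        if not futures:
            # Nothing scheduled at all: tell the caller to stop looping.
            raise OutOfData()
        # The main process works on a batch of its own in the meantime.
        main_batch = next(batch_iter, None)
        if main_batch is not None:
            yield from work(main_batch)
        # Then drain the workers.
        for future in as_completed(futures):
            yield from future.result()

The caller loops `while True: try: yield from one_round(...) except OutOfData: break`, which mirrors the new `_multiprocess` loop further down.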
@@ -317,6 +311,8 @@ def get_entities_multi_texts(
         n_process: int = 1,
         batch_size: int = -1,
         batch_size_chars: int = 1_000_000,
+        save_dir_path: Optional[str] = None,
+        batches_per_save: int = 20,
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         """Get entities from multiple texts (potentially in parallel).

@@ -343,6 +339,16 @@ def get_entities_multi_texts(
                 Each process will be given a batch of texts with a total
                 number of characters not exceeding this value. Defaults
                 to 1,000,000 characters. Set to -1 to disable.
+            save_dir_path (Optional[str]):
+                The path to where (if specified) the results are saved.
+                The directory will have an `annotated_ids.pickle` file
+                containing a tuple[list[str], int] with the list of
+                indices already saved and the number of parts already
+                saved. In addition, there will be (usually multiple)
+                files in the `part_<num>.pickle` format with the
+                partial outputs.
+            batches_per_save (int):
+                The number of batches to save (if `save_dir_path` is
+                specified) at once. Defaults to 20.

         Yields:
             Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
@@ -352,15 +358,27 @@ def get_entities_multi_texts(
             Union[Iterator[str], Iterator[tuple[str, str]]], iter(texts))
         batch_iter = self._generate_batches(
             text_iter, batch_size, batch_size_chars, only_cui)
+        if save_dir_path:
+            saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
+        else:
+            saver = None
         if n_process == 1:
             # just do in series
             for batch in batch_iter:
-                for _, text_index, result in self._mp_worker_func(batch):
-                    yield text_index, result
+                batch_results = self._mp_worker_func(batch)
+                if saver is not None:
+                    saver(batch_results)
+                yield from batch_results
+            if saver:
+                # save remainder
+                saver._save_cache()
             return

         with self._no_usage_monitor_exit_flushing():
-            yield from self._multiprocess(n_process, batch_iter)
+            yield from self._multiprocess(n_process, batch_iter, saver)
+        if saver:
+            # save remainder
+            saver._save_cache()

     @contextmanager
     def _no_usage_monitor_exit_flushing(self):
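For orientation, a hedged usage sketch of the new saving path. The `medcat.cat` import path, the `load_model_pack` entry point, and the model path are assumptions for illustration; only the `get_entities_multi_texts` signature and the on-disk layout come from this diff:

    import pickle

    from medcat.cat import CAT  # assumed import path

    cat = CAT.load_model_pack("<model_pack_path>")  # placeholder path
    docs = [("doc-1", "Patient denies chest pain."),
            ("doc-2", "History of type 2 diabetes.")]

    # Results stream back as (text_index, entities); every 20 batches
    # are also pickled under out/annotations.
    for text_index, ents in cat.get_entities_multi_texts(
            docs, n_process=2, save_dir_path="out/annotations"):
        print(text_index)

    # Reading the persisted parts back, per the docstring's layout
    # (part numbering assumed to start at 0):
    with open("out/annotations/annotated_ids.pickle", "rb") as f:
        saved_indices, num_parts = pickle.load(f)
    for num in range(num_parts):
        with open(f"out/annotations/part_{num}.pickle", "rb") as f:
            partial_results = pickle.load(f)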
@@ -379,7 +397,8 @@ def _no_usage_monitor_exit_flushing(self):

     def _multiprocess(
         self, n_process: int,
-        batch_iter: Iterator[list[tuple[str, str, bool]]]
+        batch_iter: Iterator[list[tuple[str, str, bool]]],
+        saver: Optional[BatchAnnotationSaver],
     ) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
         external_processes = n_process - 1
         if self.FORCE_SPAWN_MP:
@@ -390,8 +409,12 @@ def _multiprocess(
                 "libraries using threads or native extensions.")
             mp.set_start_method("spawn", force=True)
         with ProcessPoolExecutor(max_workers=external_processes) as executor:
-            yield from self._mp_one_batch_per_process(
-                executor, batch_iter, external_processes)
+            while True:
+                try:
+                    yield from self._mp_one_batch_per_process(
+                        executor, batch_iter, external_processes, saver=saver)
+                except OutOfDataException:
+                    break

     def _get_entity(self, ent: MutableEntity,
                     doc_tokens: list[str],
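Context for the forced "spawn" above: a forked child inherits the parent's threads and native-extension state mid-flight, which can deadlock; "spawn" starts a clean interpreter instead. That is also why `_mp_worker_func` re-initialises component and addon paths at the top: spawned workers arrive with none of that state set.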
@@ -737,7 +760,6 @@ def load_addons(
         ]
         return [(addon.full_name, addon) for addon in loaded_addons]

-
     @overload
     def get_model_card(self, as_dict: Literal[True]) -> ModelCard:
         pass
@@ -794,3 +816,7 @@ def __eq__(self, other: Any) -> bool:
     def add_addon(self, addon: AddonComponent) -> None:
         self.config.components.addons.append(addon.config)
         self._pipeline.add_addon(addon)
+
+
+class OutOfDataException(ValueError):
+    pass
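A note on the sentinel: `_mp_one_batch_per_process` is a generator, so a plain `return` (the old behaviour) merely ends one `yield from` and would leave the new `while True` loop in `_multiprocess` spinning forever. Raising `OutOfDataException` once no batch can be scheduled gives the outer loop an unambiguous stop signal.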