Skip to content

Commit 228aab9

Browse files
authored
CU-869ahw0mw: Add argument to control data flow when saving results. (#144)
* CU-869ahw0mw: Add argument to control data flow when saving results. When `save_dir_path` is provided, the user (probably) expects the data to be saved on disk upon method call. But the current implementation forced the user to iterate over the results to force the annotation to actually happen. So this change allows the method to materialise the list internally to force the annotation to happen and results to be saved on disk. Additionally, it adds 2 other options: 1. The lazy iteration (what happens when no `save_dir_path` is provided) where the iteration of data is left to the user 2. The combined / save and return option where the results are materialised, but also yielded. Notably, this will take up a lot of memory if/when used with large data sets * CU-869ahw0mw: Make tests run without materialising the output for multiprocessing * CU-869ahw0mw: Move DeID tests to non-deprecated method * CU-869ahw0mw: Some whitespace fixes * CU-869ahw0mw: Fix issue with the multiprocessing. The previous implementation would always consider the method a generator. And as such, the work would never be done at call time, regardless of whether or not the `save_dir_path` was provided. This commit fixes that by making the wrapper method a regular method that (sometimes) returns the iterator and other times just a (potentially empty) list. * CU-869ahw0mw: Add further tests to new functionality * CU-869ahw0mw: Fix behaviour (so it remains the same) in old test * CU-869ahw0mw: Fix test regarding generator issue * CU-869ahw0mw: Move saving (and not returning data) to a separate method * CU-869ahw0mw: Update tests accordingly as per last change
1 parent a58da4b commit 228aab9

File tree

3 files changed

+93
-7
lines changed

3 files changed

+93
-7
lines changed

medcat-v2/medcat/cat.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from concurrent.futures import ProcessPoolExecutor, as_completed, Future
77
import itertools
88
from contextlib import contextmanager
9+
from collections import deque
910

1011
import shutil
1112
import zipfile
@@ -318,6 +319,57 @@ def _mp_one_batch_per_process(
318319
# Yield all results from this batch
319320
yield from cur_results
320321

322+
def save_entities_multi_texts(
323+
self,
324+
texts: Union[Iterable[str], Iterable[tuple[str, str]]],
325+
save_dir_path: str,
326+
only_cui: bool = False,
327+
n_process: int = 1,
328+
batch_size: int = -1,
329+
batch_size_chars: int = 1_000_000,
330+
batches_per_save: int = 20,
331+
) -> None:
332+
"""Saves the resulting entities on disk and allows multiprocessing.
333+
334+
This uses `get_entities_multi_texts` under the hood. But it is designed
335+
to save the data on disk as it comes through.
336+
337+
Args:
338+
texts (Union[Iterable[str], Iterable[tuple[str, str]]]):
339+
The input text. Either an iterable of raw text or one
340+
with in the format of `(text_index, text)`.
341+
save_dir_path (str):
342+
The path where the results are saved. The directory will have
343+
an `annotated_ids.pickle` file containing the
344+
`tuple[list[str], int]` with a list of indices already saved
345+
and the number of parts already saved. In addition there will
346+
be (usually multiple) files in the `part_<num>.pickle` format
347+
with the partial outputs.
348+
only_cui (bool):
349+
Whether to only return CUIs rather than other information
350+
like start/end and annotated value. Defaults to False.
351+
n_process (int):
352+
Number of processes to use. Defaults to 1.
353+
batch_size (int):
The number of texts to batch at a time. A batch of the
354+
specified size will be given to each worker process.
355+
Defaults to -1 and in this case the character count will
356+
be used instead.
357+
batch_size_chars (int):
358+
The maximum number of characters to process in a batch.
359+
Each process will be given batch of texts with a total
360+
number of characters not exceeding this value. Defaults
361+
to 1,000,000 characters. Set to -1 to disable.
batches_per_save (int):
The number of batches to process before saving a part
file to disk. Defaults to 20.
362+
"""
363+
if save_dir_path is None:
364+
raise ValueError("Need to specify a save path (`save_dir_path`), "
365+
f"got {save_dir_path}")
366+
out_iter = self.get_entities_multi_texts(
367+
texts, only_cui=only_cui, n_process=n_process,
368+
batch_size=batch_size, batch_size_chars=batch_size_chars,
369+
save_dir_path=save_dir_path, batches_per_save=batches_per_save)
370+
# NOTE: not keeping anything since it'll be saved on disk
371+
deque(out_iter, maxlen=0)
372+
321373
def get_entities_multi_texts(
322374
self,
323375
texts: Union[Iterable[str], Iterable[tuple[str, str]]],
@@ -376,6 +428,15 @@ def get_entities_multi_texts(
376428
saver = BatchAnnotationSaver(save_dir_path, batches_per_save)
377429
else:
378430
saver = None
431+
yield from self._get_entities_multi_texts(
432+
n_process=n_process, batch_iter=batch_iter, saver=saver)
433+
434+
def _get_entities_multi_texts(
435+
self,
436+
n_process: int,
437+
batch_iter: Iterator[list[tuple[str, str, bool]]],
438+
saver: Optional[BatchAnnotationSaver],
439+
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
379440
if n_process == 1:
380441
# just do in series
381442
for batch in batch_iter:

medcat-v2/tests/test_cat.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,13 +534,14 @@ def _do_mp_run_with_save(
534534
for name in self.cdb.name2info
535535
for negname in self.cdb.name2info if name != negname
536536
]
537-
out_data = list(self.cat.get_entities_multi_texts(
537+
out_data = self.cat.get_entities_multi_texts(
538538
in_data,
539539
save_dir_path=save_to,
540540
batch_size_chars=chars_per_batch,
541541
batches_per_save=batches_per_save,
542542
n_process=n_process,
543-
))
543+
)
544+
out_data = list(out_data)
544545
out_dict_all = {
545546
key: cdata for key, cdata in out_data
546547
}
@@ -658,6 +659,29 @@ def test_mp_saves_correct_data_with_3_proc(self):
658659
self.assert_correct_loaded_output(
659660
in_data, out_dict_all, all_loaded_output)
660661

662+
def test_get_entities_multi_texts_with_save_dir_lazy(self):
663+
texts = ["text1", "text2"]
664+
with tempfile.TemporaryDirectory() as tmp_dir:
665+
out = self.cat.get_entities_multi_texts(
666+
texts,
667+
save_dir_path=tmp_dir)
668+
# nothing before manual iter
669+
self.assertFalse(os.listdir(tmp_dir))
670+
out_list = list(out)
671+
# something was saved
672+
self.assertTrue(os.listdir(tmp_dir))
673+
# and something was yielded
674+
self.assertEqual(len(out_list), len(texts))
675+
676+
def test_save_entities_multi_texts(self):
677+
texts = ["text1", "text2"]
678+
with tempfile.TemporaryDirectory() as tmp_dir:
679+
self.cat.save_entities_multi_texts(
680+
texts,
681+
save_dir_path=tmp_dir)
682+
# stuff was already saved
683+
self.assertTrue(os.listdir(tmp_dir))
684+
661685

662686
class CATWithDocAddonTests(CATIncludingTests):
663687
EXAMPLE_TEXT = "Example text to tokenize"

medcat-v2/tests/utils/ner/test_deid.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,15 @@ def test_model_works_deid_text_redact(self):
213213
self.assert_deid_redact(anon_text)
214214

215215
def test_model_works_deid_multi_text_single_threaded(self):
216-
processed = self.deid_model.deid_multi_text([input_text, input_text], n_process=1)
216+
processed = self.deid_model.deid_multi_texts([input_text, input_text],
217+
n_process=1)
217218
self.assertEqual(len(processed), 2)
218219
for anon_text in processed:
219220
self.assert_deid_annotations(anon_text)
220221

221222
def test_model_works_deid_multi_text_single_threaded_redact(self):
222-
processed = self.deid_model.deid_multi_text([input_text, input_text],
223-
n_process=1, redact=True)
223+
processed = self.deid_model.deid_multi_texts([input_text, input_text],
224+
n_process=1, redact=True)
224225
self.assertEqual(len(processed), 2)
225226
for anon_text in processed:
226227
self.assert_deid_redact(anon_text)
@@ -229,7 +230,7 @@ def test_model_works_deid_multi_text_single_threaded_redact(self):
229230
@unittest.skip("Deid Multiprocess is broken. Exits the process, no errors shown")
230231
def test_model_can_multiprocess_no_redact(self):
231232

232-
processed = self.deid_model.deid_multi_text(
233+
processed = self.deid_model.deid_multi_texts(
233234
[input_text, input_text], n_process=2)
234235
self.assertEqual(len(processed), 2)
235236
for tid, new_text in enumerate(processed):
@@ -245,7 +246,7 @@ def test_model_can_multiprocess_redact(self):
245246
"""
246247
try:
247248
print("Calling test_model_can_multiprocess_redact")
248-
processed = self.deid_model.deid_multi_text(
249+
processed = self.deid_model.deid_multi_texts(
249250
[input_text, input_text], n_process=2, redact=True
250251
)
251252
print("Finished processing")

0 commit comments

Comments
 (0)