From deabd1fedbea6c05ed4c41a9606a17802a42fee1 Mon Sep 17 00:00:00 2001 From: bvdka Date: Thu, 4 Jul 2024 14:17:10 +0200 Subject: [PATCH] show progress option for batch_recognition and batch_detection --- surya/detection.py | 8 ++++---- surya/recognition.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/surya/detection.py b/surya/detection.py index cf439acf..afc9aa98 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -24,7 +24,7 @@ def get_batch_size(): return batch_size -def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]: +def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None, show_progress=True) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]: assert all([isinstance(image, Image.Image) for image in images]) if batch_size is None: batch_size = get_batch_size() @@ -51,7 +51,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor, batches.append(current_batch) all_preds = [] - for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"): + for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes", disable=not show_progress): batch_image_idxs = batches[batch_idx] batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs]) @@ -122,8 +122,8 @@ def parallel_get_lines(preds, orig_sizes): return result -def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]: - preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size) +def batch_text_detection(images: List, model, processor, batch_size=None, show_progress=True) -> List[TextDetectionResult]: + preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size, show_progress=show_progress) results = [] if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images for i in range(len(images)): diff --git a/surya/recognition.py b/surya/recognition.py index 8ce74edb..115aa964 100644 --- a/surya/recognition.py +++ b/surya/recognition.py @@ -22,7 +22,7 @@ def get_batch_size(): return batch_size -def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None): +def batch_recognition(images: List, languages: List[List[str]], model, processor, batch_size=None, show_progress=True): assert all([isinstance(image, Image.Image) for image in images]) assert len(images) == len(languages) @@ -60,7 +60,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor processed_batches = processor(text=[""] * len(images), images=images, lang=languages) - for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"): + for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text", disable=not show_progress): batch_langs = languages[i:i+batch_size] has_math = ["_math" in lang for lang in batch_langs]