diff --git a/.gitignore b/.gitignore
index 137e2d7..8481bab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,7 +5,7 @@
 output/*
 data/*
 temp*
-test*
+test-magic-pdf.py
 
 # python
 .ipynb_checkpoints
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..025a2ea
--- /dev/null
+++ b/app.py
@@ -0,0 +1,58 @@
+# refactoring pdf_extract.py
+import time
+import argparse
+
+from app_tools.config import setup_logging
+from app_tools.pdf import PDFProcessor
+from app_tools.layout_analysis import LayoutAnalyzer
+from app_tools.formula_analysis import FormulaProcessor
+from app_tools.ocr_analysis import OCRProcessor
+from app_tools.table_analysis import TableProcessor
+from app_tools.visualize import get_visualize
+from app_tools.utils import save_file
+
+logger = setup_logging('app')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Process PDF files and render output images.")
+    parser.add_argument('--pdf', type=str, required=True, help="Path to the input PDF file")
+    parser.add_argument('--output', type=str, default="output", help="Output directory or filename prefix (default: 'output')")
+    parser.add_argument('--batch-size', type=int, default=128, help="Batch size for processing (default: 128)")
+    parser.add_argument('--vis', action='store_true', help="Enable visualization mode")
+    parser.add_argument('--render', action='store_true', help="Enable rendering mode")
+    args = parser.parse_args()
+    logger.info("Arguments: %s", args)
+
+    logger.info('Started!')
+    start = time.time()
+    ## ======== model init ========##
+    analyzer = LayoutAnalyzer()
+    formulas = FormulaProcessor()
+    ocr_processor = OCRProcessor(show_log=True)
+    table_processor = TableProcessor()
+    logger.info(f'Model init done in {int(time.time() - start)}s!')
+    ## ======== model init ========##
+
+    start = time.time()
+    pdf_processor = PDFProcessor()
+    all_pdfs = pdf_processor.check_pdf(args.pdf)
+
+    for idx, single_pdf, img_list in pdf_processor.process_all_pdfs(all_pdfs):
+
+        doc_layout_result = analyzer.detect_layout(img_list)
+        doc_layout_result = formulas.detect_recognize_formulas(img_list, doc_layout_result, args.batch_size)
+        doc_layout_result = ocr_processor.recognize_ocr(img_list, doc_layout_result)
+        doc_layout_result = table_processor.recognize_tables(img_list, doc_layout_result)
+
+        basename = save_file(args.output, single_pdf, doc_layout_result)
+        logger.debug(f'Save file: {basename}.json')
+
+        if args.vis:
+            logger.info("Visualization mode enabled")
+            get_visualize(img_list, doc_layout_result, args.render, args.output, basename)
+        else:
+            logger.info("Visualization mode disabled")
+
+        logger.info(f'Finished! time cost: {int(time.time() - start)} s')
+        logger.info('----------------------------------------')
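
For reference, a minimal sketch of driving the refactored pipeline without the CLI. It mirrors the loop above but keeps only the layout stage; the example asset path and the default config locations are taken from elsewhere in this diff, not guaranteed by the modules themselves.

from app_tools.pdf import PDFProcessor
from app_tools.layout_analysis import LayoutAnalyzer
from app_tools.utils import save_file

pdf_processor = PDFProcessor()
analyzer = LayoutAnalyzer()
all_pdfs = pdf_processor.check_pdf('assets/examples/example.pdf')
for idx, single_pdf, img_list in pdf_processor.process_all_pdfs(all_pdfs):
    doc_layout_result = analyzer.detect_layout(img_list)
    save_file('output', single_pdf, doc_layout_result)
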
diff --git a/app_tools/__init__.py b/app_tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app_tools/config.py b/app_tools/config.py
new file mode 100644
index 0000000..f1e3a09
--- /dev/null
+++ b/app_tools/config.py
@@ -0,0 +1,89 @@
+import os
+import yaml
+import pytz
+import logging
+import logging.config
+from datetime import datetime
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.dirname(current_dir)
+model_configs_path = os.path.join(parent_dir, 'configs/model_configs.yaml')
+
+################### MODEL CONFIGS ###################
+
+def load_config(config_path: str = None):
+    # accept an optional override path; the processors call load_config(config_path)
+    with open(config_path or model_configs_path) as f:
+        model_configs = yaml.load(f, Loader=yaml.FullLoader)
+    return model_configs
+
+################### LOGGING CONFIG ###################
+
+# TODO: Define a suitable log_file_path
+log_file_path = os.path.join(parent_dir, 'app_logs.log')
+
+# TODO: Add to config file
+TIMEZONE: str = "Asia/Shanghai"  # 'Europe/Madrid'
+timezone = pytz.timezone(TIMEZONE)  # Specify your time zone here
+
+
+class CustomFormatter(logging.Formatter):
+    def __init__(self, fmt=None, datefmt=None, tz=None):
+        super().__init__(fmt=fmt, datefmt=datefmt)
+        self.tz = tz
+
+    def formatTime(self, record, datefmt=None):
+        dt = datetime.fromtimestamp(record.created, self.tz)
+        if datefmt:
+            s = dt.strftime(datefmt)
+        else:
+            try:
+                s = dt.isoformat(timespec='milliseconds')
+            except TypeError:
+                s = dt.isoformat()
+        return s
+
+# Basic logging configuration
+LOGGING_CONFIG = {
+    'version': 1,
+    'disable_existing_loggers': False,
+    'formatters': {
+        'standard': {
+            '()': CustomFormatter,
+            'format': '%(asctime)s - %(name)s - [%(levelname)s] - %(message)s',
+            'datefmt': '%Y-%m-%d %H:%M:%S',
+            'tz': timezone
+        },
+    },
+    'handlers': {
+        'console': {
+            'level': 'DEBUG',
+            'class': 'logging.StreamHandler',
+            'formatter': 'standard'
+        },
+        'file': {
+            'level': 'INFO',
+            'class': 'logging.FileHandler',
+            'filename': log_file_path,
+            'formatter': 'standard'
+        },
+    },
+    'loggers': {
+        '': {  # root logger
+            'handlers': ['console', 'file'],
+            'level': 'DEBUG',
+            'propagate': True
+        },
+        '__name__': {
+            'handlers': ['console'],
+            'level': 'INFO',
+            'propagate': False
+        }
+    }
+}
+
+def setup_logging(name: str = '__main__'):
+    """Configures logging for the entire library."""
+    logging.config.dictConfig(LOGGING_CONFIG)
+    # Get the logger for this specific module
+    logger = logging.getLogger(name)
+    return logger
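
A quick sketch of how these two helpers are consumed by the other modules in this diff. The 'model_args'/'device' keys are assumptions taken from how the processors below read the config, not a documented schema.

from app_tools.config import load_config, setup_logging

logger = setup_logging('example')
config = load_config()  # or load_config('/path/to/other_configs.yaml')
logger.info('Configured device: %s', config['model_args']['device'])
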
diff --git a/app_tools/formula_analysis.py b/app_tools/formula_analysis.py
new file mode 100644
index 0000000..55ab737
--- /dev/null
+++ b/app_tools/formula_analysis.py
@@ -0,0 +1,237 @@
+import os, gc
+import time
+import argparse
+from typing import List, Tuple
+
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from ultralytics import YOLO
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+from unimernet.processors import load_processor  ## TODO: WARNING 'load_processor' is not declared in __all__
+from modules.post_process import get_croped_image, latex_rm_whitespace
+
+from app_tools.config import load_config, setup_logging
+
+logger = setup_logging('formula_analysis')
+
+
+class MathDataset(Dataset):
+    def __init__(self, image_paths, transform=None):
+        self.image_paths = image_paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        if idx >= len(self.image_paths) or idx < 0:
+            raise IndexError("Index out of range")
+
+        # if a path was given, open it as a PIL image; otherwise it already is one
+        if isinstance(self.image_paths[idx], str):
+            try:
+                raw_image = Image.open(self.image_paths[idx])
+            except IOError as e:
+                raise IOError(f"Error opening image: {self.image_paths[idx]}") from e
+        else:
+            raw_image = self.image_paths[idx]
+
+        # apply the transformation, if any
+        if self.transform:
+            return self.transform(raw_image)
+        return raw_image
+
+
+class FormulaProcessor:
+    """
+    Handles formula detection and recognition in page images.
+    It is initialized with an optional configuration file and provides methods
+    to detect formulas and to recognize their LaTeX content.
+    """
+
+    def __init__(self, config_path: str = None):
+        """
+        Initializes the FormulaProcessor from the provided config_path, or from the default configuration.
+
+        Attributes:
+        - config: The configuration loaded from config_path, or the default configuration.
+        - mfd_model: The math formula detection (MFD) model.
+        - mfr_model: The math formula recognition (MFR) model.
+        - mfr_transform: The image transformation applied before recognition.
+        - latex_filling_list: Layout items whose 'latex' field gets filled in during recognition.
+        - mf_image_list: Cropped formula images awaiting recognition.
+
+        Parameters:
+        - config_path (str): The path to the configuration file (optional).
+
+        Note:
+        - The _init_mfd_model and _init_mfr_model helpers initialize the models and the transform.
+        - latex_filling_list and mf_image_list start out empty and are populated by detect_formulas.
+        """
+        self.config = load_config(config_path) if config_path else load_config()
+        self.mfd_model = self._init_mfd_model()
+        self.mfr_model, self.mfr_transform = self._init_mfr_model()
+        self.latex_filling_list = []
+        self.mf_image_list = []
+
+    def _init_mfd_model(self):
+        """
+        Initializes the MFD (math formula detection) model by loading the weights
+        named in the configuration file into a new YOLO instance.
+
+        Returns:
+            mfd_model (YOLO): The initialized MFD model.
+        """
+        weight = self.config['model_args']['mfd_weight']
+        mfd_model = YOLO(weight)
+        return mfd_model
+
+    def _init_mfr_model(self) -> Tuple[torch.nn.Module, transforms.Compose]:
+        """
+        Initializes the MFR (math formula recognition) model by loading the weights and setting the device.
+
+        Returns:
+            Tuple: A tuple containing the MFR model (`torch.nn.Module`) and the transformation (`transforms.Compose`).
+        """
+        weight_dir = self.config['model_args']['mfr_weight']
+        device = self.config['model_args']['device']
+
+        args = argparse.Namespace(cfg_path="modules/UniMERNet/configs/demo.yaml", options=None)
+        cfg = Config(args)
+        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+        cfg.config.model.model_config.model_name = weight_dir
+        cfg.config.model.tokenizer_config.path = weight_dir
+        task = tasks.setup_task(cfg)
+        model = task.build_model(cfg)
+        model = model.to(device)
+        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+        mfr_transform = transforms.Compose([vis_processor])
+        return model, mfr_transform
+    def detect_formulas(self, img_list: List, doc_layout_result: List[dict]) -> List[dict]:
+        """
+        Detect formulas in the given list of images and update the document layout result with the detected formulas.
+
+        Parameters:
+        - img_list: A list of images to detect formulas in.
+        - doc_layout_result: A list of dictionaries representing the document layout result.
+          Each dictionary contains the layout details of a single page.
+
+        Returns:
+        - A list of dictionaries representing the updated document layout result with the detected formulas.
+        """
+        img_size = self.config['model_args']['img_size']
+        conf_thres = self.config['model_args']['conf_thres']
+        iou_thres = self.config['model_args']['iou_thres']
+
+        logger.debug('Formula detection - init')
+        start = time.time()
+
+        for idx, image in enumerate(img_list):
+            mfd_res = self.mfd_model.predict(image, imgsz=img_size, conf=conf_thres, iou=iou_thres, verbose=True)[0]
+
+            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    'category_id': 13 + int(cla.item()),
+                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    'score': round(float(conf.item()), 2),
+                    'latex': '',
+                }
+                doc_layout_result[idx]['layout_dets'].append(new_item)
+                self.latex_filling_list.append(new_item)
+                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
+                self.mf_image_list.append(bbox_img)
+
+            del mfd_res
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        logger.debug(f'Formula detection done in {round(time.time() - start, 2)}s!')
+
+        return doc_layout_result
+
+    def recognize_formulas(self, batch_size: int = 128):
+        """
+        Recognizes the formulas collected by detect_formulas, in batches.
+
+        The method iterates over the cropped formula images, uses the pre-trained
+        MFR model to generate a LaTeX prediction for each one, and writes the
+        cleaned-up predictions back into the corresponding layout items.
+        Finally, it logs the number of formulas recognized and the time taken.
+
+        Parameters:
+        - batch_size: The number of images to process per batch, which is useful
+          for managing memory usage. Default is 128.
+
+        Returns:
+        None
+
+        Note:
+        - The method assumes that the pre-trained model and the cropped image list
+          have already been initialized and assigned to the instance variables.
+        """
+        device = self.config['model_args']['device']
+
+        logger.debug('Formula recognition')
+        start = time.time()
+
+        dataset = MathDataset(self.mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=32)
+        mfr_res = []
+        for imgs in dataloader:
+            imgs = imgs.to(device)
+            output = self.mfr_model.generate({'image': imgs})
+            mfr_res.extend(output['pred_str'])
+        for res, latex in zip(self.latex_filling_list, mfr_res):
+            res['latex'] = latex_rm_whitespace(latex)
+
+        logger.info(f'Formula nums: {len(self.mf_image_list)} mfr time: {round(time.time() - start, 2)}')
+
+    def detect_recognize_formulas(self, img_list: List, doc_layout_result: List[dict], batch_size: int = 128):
+        """
+        Detects and recognizes formulas in document layout results.
+
+        It first detects formulas by calling `detect_formulas`, then recognizes
+        them with `recognize_formulas`, and returns the updated layout results.
+
+        Parameters:
+        - `img_list` (List): A list of images.
+        - `doc_layout_result` (List[dict]): A list of document layout results.
+        - `batch_size` (int, optional): The batch size for recognition. Defaults to 128.
+
+        Returns:
+        - `doc_layout_result` (List[dict]): The updated document layout results.
+        """
+        doc_layout_result = self.detect_formulas(img_list, doc_layout_result)
+        self.recognize_formulas(batch_size)
+        return doc_layout_result
+
+    def clear_memory(self):
+        """
+        Clears the models from memory, freeing up resources.
+        """
+        logger.info('Clearing models from memory.')
+        if self.mfd_model is not None:
+            del self.mfd_model
+            self.mfd_model = None
+
+        if self.mfr_model is not None:
+            del self.mfr_model
+            self.mfr_model = None
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        logger.info('Models successfully cleared from memory.')
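
The 'poly' field written above stores the four corners of an axis-aligned box as [x0, y0, x1, y1, x2, y2, x3, y3]; the other modules read corners 0/1 and 4/5 back as (xmin, ymin) and (xmax, ymax). A hypothetical helper, not part of this diff, that makes the round trip explicit:

def poly_to_bbox(poly):
    # corners are (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)
    xs, ys = poly[0::2], poly[1::2]
    return min(xs), min(ys), max(xs), max(ys)

assert poly_to_bbox([10, 20, 50, 20, 50, 80, 10, 80]) == (10, 20, 50, 80)
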
diff --git a/app_tools/layout_analysis.py b/app_tools/layout_analysis.py
new file mode 100644
index 0000000..eb9f920
--- /dev/null
+++ b/app_tools/layout_analysis.py
@@ -0,0 +1,136 @@
+import time
+import gc
+import torch
+from typing import Optional
+
+from modules.layoutlmv3.model_init import Layoutlmv3_Predictor
+from app_tools.config import setup_logging, load_config
+
+
+class LayoutAnalyzer:
+    """
+    Analyzes the layout of documents by detecting the layout of each page in a document image.
+
+    Attributes:
+        logger: The logger object for logging debug, info, and error messages.
+        config: The configuration settings for the layout analysis.
+        model: The layout detection model.
+
+    Methods:
+        __init__(self, config_path: Optional[str] = None)
+            Constructs a LayoutAnalyzer object.
+
+        _init_model(self) -> Layoutlmv3_Predictor
+            Initializes the layout detection model.
+
+        detect_layout(self, img_list: list) -> list
+            Detects the layout of multiple images.
+
+        clear_model(self)
+            Clears the layout detection model from memory.
+    """
+
+    def __init__(self, config_path: Optional[str] = None):
+        """
+        Initializes an instance and its model.
+
+        Args:
+            config_path (Optional[str]): The path to the configuration file. Defaults to None.
+        """
+        self.logger = setup_logging('layout_analysis')
+        self.config = load_config(config_path) if config_path else load_config()
+        self.model = self._init_model()
+
+    def _init_model(self) -> Layoutlmv3_Predictor:
+        """
+        Initializes and returns an instance of the `Layoutlmv3_Predictor` class.
+
+        Returns:
+            A `Layoutlmv3_Predictor` object initialized with the `layout_weight` value from the configuration.
+        """
+        weight = self.config['model_args']['layout_weight']
+        model = Layoutlmv3_Predictor(weight)
+        return model
+
+    def detect_layout(self, img_list: list) -> list:
+        """
+        Detects the layout of a list of images.
+
+        Parameters:
+        - `img_list`: A list of images to detect the layout of.
+
+        Returns:
+        - A list of layout results.
+
+        Raises:
+        - `ValueError`: If the model is not initialized.
+
+        The method performs the following steps:
+        1. Checks that the model is initialized; if not, it raises a `ValueError`.
+        2. Initializes an empty list `doc_layout_result` to store the layout results.
+        3. Logs a debug message and starts a timer.
+        4. For each image in `img_list`:
+           a. Gets the height and width of the image.
+           b. Passes the image to the model for layout detection.
+           c. Adds the page number, height, and width to the layout result.
+           d. Appends the layout result to `doc_layout_result`.
+           e. Drops the per-page reference and clears the GPU memory.
+        5. Logs the completion of layout detection and the time taken.
+
+        Example usage:
+        ```
+        analyzer = LayoutAnalyzer()
+        results = analyzer.detect_layout([image1, image2, image3])
+        ```
+        """
+        if self.model is None:
+            raise ValueError("Model is not initialized. Initialize the model before calling `detect_layout`.")
+
+        doc_layout_result = []
+
+        self.logger.debug('Layout detection - init')
+        start = time.time()
+
+        for idx, image in enumerate(img_list):
+            img_h, img_w = image.shape[0], image.shape[1]
+
+            layout_res = self.model(image, ignore_catids=[])
+
+            layout_res['page_info'] = {
+                'page_no': idx,
+                'height': img_h,
+                'width': img_w
+            }
+            doc_layout_result.append(layout_res)
+
+            del layout_res
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        self.logger.debug(f'Layout detection done in {round(time.time() - start, 2)}s!')
+
+        return doc_layout_result
+
+    def clear_model(self):
+        """
+        Clears the model from memory by deleting the model object, freeing GPU memory
+        with torch.cuda.empty_cache(), and collecting garbage to release unreferenced memory.
+
+        Example usage:
+            analyzer.clear_model()
+        """
+        self.logger.info('Clearing the model from memory.')
+
+        if self.model is not None:
+            del self.model
+            self.model = None
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.logger.info('Model successfully cleared from memory.')
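
For orientation, a sketch of the per-page structure detect_layout returns, inferred from how the downstream processors read it; only the 'layout_dets' and 'page_info' keys are actually relied upon in this diff, and the sample values are made up.

page_result = {
    'layout_dets': [
        # one dict per detected region; formula/OCR items are appended later
        {'category_id': 1, 'poly': [10, 20, 50, 20, 50, 80, 10, 80], 'score': 0.98},
    ],
    'page_info': {'page_no': 0, 'height': 1684, 'width': 1191},
}
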
diff --git a/app_tools/ocr_analysis.py b/app_tools/ocr_analysis.py
new file mode 100644
index 0000000..d6d23e9
--- /dev/null
+++ b/app_tools/ocr_analysis.py
@@ -0,0 +1,100 @@
+import time
+import cv2
+import numpy as np
+from PIL import Image
+
+from modules.self_modify import ModifiedPaddleOCR
+
+from app_tools.config import setup_logging
+
+
+class OCRProcessor:
+    """
+    This class represents an OCR processor.
+    It performs OCR recognition on a list of images, restricted to the layout categories that require it.
+
+    Attributes:
+        logger: Logger object for logging OCR analysis.
+        ocr_model: Instance of the ModifiedPaddleOCR class used for OCR recognition.
+
+    Methods:
+        __init__(self, show_log: bool = True)
+            Initializes the OCRProcessor object with a logger and an instance of the ModifiedPaddleOCR class.
+
+        recognize_ocr(self, img_list: list, doc_layout_result: list) -> list:
+            Performs OCR recognition on a list of images based on the given document layout results.
+            Returns the document layout result list with the newly recognized text appended to it.
+    """
+
+    def __init__(self, show_log: bool = True):
+        """
+        Initializes the OCRProcessor object.
+
+        Parameters:
+            show_log (bool): Whether the OCR model should display log messages. Default is True.
+        """
+        self.logger = setup_logging('ocr_analysis')
+        self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
+
+    def recognize_ocr(self, img_list: list, doc_layout_result: list) -> list:
+        """
+        Performs Optical Character Recognition (OCR) on a list of images and appends the recognized text to the document layout result.
+
+        Parameters:
+        - `img_list` (list): A list of images in numpy array format.
+        - `doc_layout_result` (list): A list containing the document layout results, each entry representing a page in the document.
+
+        Returns:
+        - `doc_layout_result` (list): The updated document layout result list with recognized text appended.
+
+        1. Converts each input image from RGB to BGR using OpenCV's `cv2.cvtColor`.
+        2. Iterates over each image and its corresponding layout details in the document layout result.
+        3. For each layout detail with category ID 13 or 14 (the formula categories), extracts the
+           bounding box coordinates and adds them to the `single_page_mfdetrec_res` list, which is
+           later passed to the OCR model so it can account for the formula regions.
+        4. For each layout detail with a category ID in [0, 1, 2, 4, 6, 7] (the categories that require OCR),
+           extracts the bounding box coordinates, pastes the region of interest cropped with `pil_img.crop`
+           onto a white page-sized canvas, and converts the result back to BGR.
+        5. Calls `self.ocr_model.ocr` with the cropped image and the `single_page_mfdetrec_res` list.
+           The OCR result is a list of bounding boxes with their corresponding recognized text.
+        6. If the OCR result is not empty, iterates over each bounding box and text pair, extracts the
+           four corner points, the confidence score, and the recognized text, and appends a new layout
+           detail with category ID 15 (OCR text) to `doc_layout_result`.
+        7. Logs the time taken for the OCR recognition and returns the updated `doc_layout_result` list.
+        """
+        self.logger.debug('OCR recognition - init')
+        start = time.time()
+
+        for idx, image in enumerate(img_list):
+            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+            single_page_res = doc_layout_result[idx]['layout_dets']
+            single_page_mfdetrec_res = []
+
+            for res in single_page_res:
+                if int(res['category_id']) in [13, 14]:  # Categories for formula
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    single_page_mfdetrec_res.append({
+                        "bbox": [xmin, ymin, xmax, ymax],
+                    })
+
+            for res in single_page_res:
+                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # Categories that need OCR
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = Image.new('RGB', pil_img.size, 'white')
+                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
+                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
+                    ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
+                    if ocr_res:
+                        for box_ocr_res in ocr_res:
+                            p1, p2, p3, p4 = box_ocr_res[0]
+                            text, score = box_ocr_res[1]
+                            doc_layout_result[idx]['layout_dets'].append({
+                                'category_id': 15,
+                                'poly': p1 + p2 + p3 + p4,
+                                'score': round(score, 2),
+                                'text': text,
+                            })
+
+        self.logger.info(f'OCR recognition done in: {round(time.time() - start, 2)}s')
+
+        return doc_layout_result
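
A plausible reason the crop is pasted onto a white page-sized canvas instead of running OCR on the raw crop: the recognized boxes then come back in page coordinates, so no offset arithmetic is needed when they are appended to the layout. A standalone PIL-only illustration of that trick:

from PIL import Image

page = Image.new('RGB', (200, 100), 'gray')   # stand-in for a page image
canvas = Image.new('RGB', page.size, 'white')
box = (50, 20, 150, 80)
canvas.paste(page.crop(box), box)             # region stays at (50, 20)
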
diff --git a/app_tools/pdf.py b/app_tools/pdf.py
new file mode 100644
index 0000000..9143749
--- /dev/null
+++ b/app_tools/pdf.py
@@ -0,0 +1,130 @@
+import os
+from typing import List, Optional, Generator, Tuple
+
+from modules.extract_pdf import load_pdf_fitz
+from app_tools.config import load_config, setup_logging
+
+
+class PDFProcessor:
+    """
+    This class provides a set of tools for working with PDF files.
+
+    Methods:
+    - __init__(config_path: Optional[str] = None): Initializes the PDFProcessor object.
+    - check_pdf(pdf_path: str) -> List[str]: Collects the PDF file paths to process.
+    - get_images(single_pdf: str) -> Optional[List]: Renders the pages of a PDF as images.
+    - process_all_pdfs(all_pdfs: List[str]): Yields the index, path, and page images of each PDF.
+
+    Attributes:
+    - config (dict): A dictionary containing the configuration settings.
+    - dpi (int): The DPI (dots per inch) at which PDF pages are rendered.
+    - logger (Logger): The logger object for logging messages.
+    """
+
+    def __init__(self, config_path: Optional[str] = None):
+        """
+        Initializes the PDFProcessor object.
+
+        Parameters:
+            config_path (Optional[str]): Path to the configuration file. If None, the default configuration file will be loaded.
+        """
+        self.config = load_config(config_path) if config_path else load_config()
+        self.dpi = self.config['model_args']['pdf_dpi']
+        self.logger = setup_logging('pdf_tools')
+
+    def check_pdf(self, pdf_path: str) -> List[str]:
+        """
+        Checks whether the given path is a directory or a single PDF file and collects the PDF file paths to process.
+
+        Parameters:
+        - pdf_path (str): The path to check. It can be either a directory or a single PDF file path.
+
+        Returns:
+        - List[str]: A list of PDF file paths.
+
+        Example usage:
+        ```
+        pdf_processor = PDFProcessor()
+        result = pdf_processor.check_pdf('/path/to/pdfs')
+        print(result)
+        ```
+
+        Note: This method will return an empty list if no PDF files are found in the given directory.
+        """
+        if os.path.isdir(pdf_path):
+            all_pdfs = [os.path.join(pdf_path, name) for name in os.listdir(pdf_path) if name.endswith('.pdf')]
+        else:
+            all_pdfs = [pdf_path]
+        self.logger.info(f"Total files: {len(all_pdfs)}")
+        return all_pdfs
+    def get_images(self, single_pdf: str) -> Optional[List]:
+        """
+        Renders the pages of a single PDF file as a list of images.
+
+        Parameters:
+        - single_pdf: A string representing the path to the PDF file.
+
+        Returns:
+        - Optional[List]: A list of page images extracted from the PDF file, or None if there was an error during extraction.
+
+        Example:
+            pdf_processor = PDFProcessor()
+            images = pdf_processor.get_images('example.pdf')
+        """
+        try:
+            img_list = load_pdf_fitz(single_pdf, self.dpi)
+        except Exception as e:
+            self.logger.error(f"Unexpected error with PDF file '{single_pdf}': {e}")
+            return None
+        return img_list
+
+    def process_all_pdfs(self, all_pdfs: List[str]) -> Generator[Tuple[int, str, List], None, None]:
+        """
+        Processes a list of PDF files and returns a generator that yields a tuple for each PDF file:
+        the index of the PDF file in the list, the path of the PDF file, and the list of page images
+        rendered from it. Files whose pages cannot be rendered are skipped.
+
+        Parameters:
+        - `all_pdfs`: A list of strings representing the paths of the PDF files to be processed.
+
+        Returns:
+        - `Generator[Tuple[int, str, List], None, None]`: A generator that yields one tuple per PDF file.
+
+        Example usage:
+        ```python
+        pdf_processor = PDFProcessor()
+        pdf_files = ["file1.pdf", "file2.pdf", "file3.pdf"]
+        for index, path, images in pdf_processor.process_all_pdfs(pdf_files):
+            print(f"Processing PDF index: {index}")
+            print(f"PDF path: {path}")
+            print(f"Pages: {len(images)}")
+        ```
+        """
+        for idx, single_pdf in enumerate(all_pdfs):
+            img_list = self.get_images(single_pdf)
+
+            if img_list is None:
+                continue
+
+            self.logger.info(f"PDF index: {idx}, pages: {len(img_list)}")
+            yield idx, single_pdf, img_list
\ No newline at end of file
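
A minimal sketch of using PDFProcessor on its own; the directory path is hypothetical, and the default config must provide the 'model_args.pdf_dpi' entry read in __init__.

from app_tools.pdf import PDFProcessor

pdf_processor = PDFProcessor()
all_pdfs = pdf_processor.check_pdf('/path/to/pdfs')
for idx, path, images in pdf_processor.process_all_pdfs(all_pdfs):
    print(idx, path, len(images))
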
diff --git a/app_tools/table_analysis.py b/app_tools/table_analysis.py
new file mode 100644
index 0000000..b712034
--- /dev/null
+++ b/app_tools/table_analysis.py
@@ -0,0 +1,134 @@
+import time
+import torch
+import gc
+from PIL import Image
+from struct_eqtable import build_model
+
+from app_tools.config import load_config, setup_logging
+
+
+class TableProcessor:
+    """
+    This class represents a table processor used for table recognition in documents.
+    """
+
+    def __init__(self, config_path: str = None):
+        """
+        Initializes a TableProcessor object.
+
+        It takes an optional `config_path` parameter which specifies the path to a configuration file.
+        If no `config_path` is provided, the default configuration will be used.
+        This method initializes a logger, loads the configuration, and calls the
+        `_init_tr_model` method to initialize the table recognition model.
+
+        Attributes:
+        - logger: The logger instance for logging debug and error messages.
+        - config: The configuration loaded from the provided path, or the default configuration.
+        - tr_model: The initialized table recognition model.
+
+        Note: This class relies on the 'setup_logging' and 'load_config' functions from app_tools.config.
+        """
+        self.logger = setup_logging('table_analysis')
+        self.config = load_config(config_path) if config_path else load_config()
+        self.tr_model = self._init_tr_model()
+
+    def _init_tr_model(self):
+        """
+        Initializes the table recognition model.
+
+        Reads the weights, maximum time, and device from the configuration and builds
+        the model using the `build_model` function from struct_eqtable.
+
+        Returns:
+        - tr_model : object
+            The initialized table recognition model.
+        """
+        weight = self.config['model_args']['tr_weight']
+        max_time = self.config['model_args']['table_max_time']
+        device = self.config['model_args']['device']
+
+        tr_model = build_model(weight, max_new_tokens=4096, max_time=max_time)
+        if device == 'cuda':
+            tr_model = tr_model.cuda()
+        return tr_model
+    def recognize_tables(self, img_list: list, doc_layout_result: list) -> list:
+        """
+        Recognizes tables in a list of images based on the document layout results.
+
+        Parameters:
+        - img_list: a list of images to perform table recognition on
+        - doc_layout_result: a list containing the layout details of the document
+
+        Returns:
+        - A modified version of doc_layout_result with table recognition results added
+
+        The method iterates through each image in img_list and retrieves the layout details
+        for that image from doc_layout_result. For each layout detail with category_id 5
+        (a table), it crops the image using the polygon coordinates of the layout detail
+        and performs table recognition on the cropped image.
+
+        The table recognition operation might take significant time, so the elapsed time is
+        compared against the configured maximum; if recognition exceeded it, a "timeout" flag
+        is set to True on the layout detail. The recognized LaTeX output is assigned to the
+        "latex" property of the layout detail.
+
+        Finally, the method logs the completion of the table recognition process and returns
+        the modified doc_layout_result.
+
+        Note: The method uses the torch and PIL libraries for image processing and table recognition.
+        """
+        max_time = self.config['model_args']['table_max_time']
+
+        self.logger.debug('Table recognition - init')
+        start_0 = time.time()
+
+        for idx, image in enumerate(img_list):
+            pil_img = Image.fromarray(image)
+            single_page_res = doc_layout_result[idx]['layout_dets']
+
+            for jdx, res in enumerate(single_page_res):
+                if int(res['category_id']) == 5:  # Perform table recognition
+                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
+                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
+                    crop_box = (xmin, ymin, xmax, ymax)
+                    cropped_img = pil_img.crop(crop_box)
+
+                    start = time.time()
+                    with torch.no_grad():
+                        start_1 = time.time()
+                        output = self.tr_model(cropped_img)  # This operation might take significant time
+                        self.logger.debug(f'{idx} - {jdx} tr_model generate in: {round(time.time() - start_1, 2)}s')
+
+                    if (time.time() - start) > max_time:
+                        res["timeout"] = True
+                    res["latex"] = output[0]
+
+        self.logger.info(f'Table recognition done in: {round(time.time() - start_0, 2)}s')
+
+        return doc_layout_result
+
+    def clear_memory(self):
+        """
+        Clears the table recognition model from memory.
+
+        Deletes the model object, clears the CUDA cache, and performs garbage collection.
+
+        Example:
+            table_processor.clear_memory()
+        """
+        self.logger.info('Clearing the table recognition model from memory.')
+
+        if self.tr_model is not None:
+            del self.tr_model
+            self.tr_model = None
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.logger.info('Table recognition model successfully cleared from memory.')
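
A minimal sketch of running table recognition on a single synthetic page; the layout item and its coordinates are made up, and the configured struct_eqtable weights must exist for TableProcessor() to load.

import numpy as np
from app_tools.table_analysis import TableProcessor

page = np.full((1000, 800, 3), 255, dtype=np.uint8)  # blank white page
layout = [{'layout_dets': [{'category_id': 5,
                            'poly': [100, 100, 700, 100, 700, 400, 100, 400]}]}]
table_processor = TableProcessor()
layout = table_processor.recognize_tables([page], layout)
print(layout[0]['layout_dets'][0].get('latex', ''))
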
diff --git a/app_tools/utils.py b/app_tools/utils.py
new file mode 100644
index 0000000..e93410e
--- /dev/null
+++ b/app_tools/utils.py
@@ -0,0 +1,21 @@
+import os
+import json
+
+
+def save_file(output_dir, single_pdf, doc_layout_result):
+    """
+    Saves the document layout result as a JSON file in the specified output directory.
+
+    Parameters:
+    - output_dir (str): The directory where the JSON file will be saved.
+    - single_pdf (str): The path to the single PDF file.
+    - doc_layout_result (list): The document layout result that will be saved as a JSON file.
+
+    Returns:
+    - basename (str): The basename of the single PDF file, without its extension.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    basename = os.path.splitext(os.path.basename(single_pdf))[0]  # strip the '.pdf' extension
+    with open(os.path.join(output_dir, f'{basename}.json'), 'w') as f:
+        json.dump(doc_layout_result, f)
+    return basename
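
A quick check of save_file's behavior; the input path here is made up, and the 'output' directory is created on demand.

from app_tools.utils import save_file

basename = save_file('output', '/data/papers/example.pdf', [{'layout_dets': []}])
print(basename)  # 'example'; the JSON lands in output/example.json
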
diff --git a/app_tools/visualize.py b/app_tools/visualize.py
new file mode 100644
index 0000000..2428afc
--- /dev/null
+++ b/app_tools/visualize.py
@@ -0,0 +1,107 @@
+import os
+import shutil
+import cv2
+from PIL import Image, ImageDraw, ImageFont
+
+from modules.latex2png import tex2pil, zhtext2pil
+from app_tools.config import setup_logging
+
+# Apply the logging configuration
+logger = setup_logging('visualize')
+
+color_palette = [
+    (255, 64, 255), (255, 255, 0), (0, 255, 255), (255, 215, 135), (215, 0, 95), (100, 0, 48), (0, 175, 0),
+    (95, 0, 95), (175, 95, 0), (95, 95, 0),
+    (95, 95, 255), (95, 175, 135), (215, 95, 0), (0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 95, 215),
+    (0, 0, 0), (0, 0, 0), (0, 0, 0)
+]
+id2names = ["title", "plain_text", "abandon", "figure", "figure_caption", "table", "table_caption",
+            "table_footnote",
+            "isolate_formula", "formula_caption", " ", " ", " ", "inline_formula", "isolated_formula",
+            "ocr_text"]
+
+
+def get_visualize(img_list: list, doc_layout_result: list, render: bool, output_dir: str, basename: str):
+    """
+    Generates visualizations of the document layout and saves them as a PDF file.
+
+    Parameters:
+    - img_list (list): A list of images. Each image should be a numpy array.
+    - doc_layout_result (list): The result of a document layout analysis: a list of dictionaries,
+      where each dictionary represents the layout details of a single page, including the
+      category ID, polygon coordinates, and text/latex content of each item.
+    - render (bool): A boolean flag indicating whether to render the text/latex content in the visualizations.
+    - output_dir (str): The output directory where the PDF file will be saved.
+    - basename (str): The basename of the PDF file.
+
+    Returns:
+    None
+
+    Example usage:
+        img_list = [image1, image2, ...]  # list of images
+        doc_layout_result = [...]         # list of layout dictionaries
+        render = True  # or False
+        output_dir = '/path/to/output/directory'
+        basename = 'output'
+        get_visualize(img_list, doc_layout_result, render, output_dir, basename)
+    """
+    vis_pdf_result = []
+
+    for idx, image in enumerate(img_list):
+        single_page_res = doc_layout_result[idx]['layout_dets']
+
+        if render:
+            vis_img = Image.new('RGB', Image.fromarray(image).size, 'white')
+        else:
+            vis_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+        draw = ImageDraw.Draw(vis_img)
+
+        for res in single_page_res:
+            label = int(res['category_id'])
+            if label > 15:  # categories that do not need to be visualized
+                continue
+            label_name = id2names[label]
+            x_min, y_min = int(res['poly'][0]), int(res['poly'][1])
+            x_max, y_max = int(res['poly'][4]), int(res['poly'][5])
+            if render and label in [13, 14, 15]:
+                try:
+                    if label in [13, 14]:  # render formula
+                        window_img = tex2pil(res['latex'])[0]
+                    else:
+                        window_img = zhtext2pil(res['text'])
+                    # This code is unreachable
+                    # if True:  # render chinese
+                    #     window_img = zhtext2pil(res['text'])
+                    # else:  # render english
+                    #     window_img = tex2pil([res['text']], tex_type="text")[0]
+                    ratio = min((x_max - x_min) / window_img.width, (y_max - y_min) / window_img.height) - 0.05
+                    window_img = window_img.resize(
+                        (int(window_img.width * ratio), int(window_img.height * ratio)))
+                    vis_img.paste(window_img, (int(x_min + (x_max - x_min - window_img.width) / 2),
+                                               int(y_min + (y_max - y_min - window_img.height) / 2)))
+                except Exception as e:
+                    logger.error(f"got exception on {res['text']}, error info: {e}")
+
+            draw.rectangle((x_min, y_min, x_max, y_max), fill=None, outline=color_palette[label], width=1)
+            font_text = ImageFont.truetype("assets/fonts/simhei.ttf", 15, encoding="utf-8")
+            draw.text((x_min, y_min), label_name, color_palette[label], font=font_text)
+
+        width, height = vis_img.size
+        width, height = int(0.75 * width), int(0.75 * height)
+        vis_img = vis_img.resize((width, height))
+        vis_pdf_result.append(vis_img)
+
+    first_page = vis_pdf_result.pop(0)
+    first_page.save(
+        fp=os.path.join(output_dir, f'{basename}.pdf'),
+        format='PDF',
+        resolution=100,
+        save_all=True,
+        append_images=vis_pdf_result
+    )
+    try:
+        shutil.rmtree('./temp')
+    except Exception as e:
+        logger.error(f"got exception on shutil.rmtree, error info: {e}")
diff --git a/assets/examples/example.pdf b/assets/examples/example.pdf
index 19aa73b..2f158e5 100644
Binary files a/assets/examples/example.pdf and b/assets/examples/example.pdf differ
diff --git a/requirements.txt b/requirements.txt
index b6a98a5..3593b7b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,15 @@
-unimernet
-matplotlib
-PyMuPDF
-ultralytics
-paddlepaddle-gpu
+unimernet==0.1.6
+matplotlib==3.9.2
+PyMuPDF==1.24.9
+ultralytics==8.2.86
+paddlepaddle-gpu==2.6.1
 paddleocr==2.7.3
-struct-eqtable==0.1.0
\ No newline at end of file
+struct-eqtable==0.1.0
+
+torch==2.3.1
+torchvision==0.18.1
+numpy==1.26.4
+opencv-python==4.6.0.66
+Pillow==8.4.0
+PyYAML==6.0.2
+pytz==2024.1