Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
output/*
data/*
temp*
test*
test-magic-pdf.py

# python
.ipynb_checkpoints
Expand Down
58 changes: 58 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# refactoring pdf_extract.py
import time
import argparse

from app_tools.config import setup_logging
from app_tools.pdf import PDFProcessor
from app_tools.layout_analysis import LayoutAnalyzer
from app_tools.formula_analysis import FormulaProcessor
from app_tools.ocr_analysis import OCRProcessor
from app_tools.table_analysis import TableProcessor
from app_tools.visualize import get_visualize
from app_tools.utils import save_file

logger = setup_logging('app')


if __name__ == '__main__':
    # ---- command-line interface ----
    cli = argparse.ArgumentParser(description="Process PDF files and render output images.")
    cli.add_argument('--pdf', type=str, required=True, help="Path to the input PDF file")
    cli.add_argument('--output', type=str, default="output", help="Output directory or filename prefix (default: 'output')")
    cli.add_argument('--batch-size', type=int, default=128, help="Batch size for processing (default: 128)")
    cli.add_argument('--vis', action='store_true', help="Enable visualization mode")
    cli.add_argument('--render', action='store_true', help="Enable rendering mode")
    args = cli.parse_args()
    logger.info("Arguments: %s", args)

    logger.info('Started!')

    # ---- model initialization (timed separately from the processing run) ----
    init_started = time.time()
    analyzer = LayoutAnalyzer()
    formulas = FormulaProcessor()
    ocr_processor = OCRProcessor(show_log=True)
    table_processor = TableProcessor()
    init_seconds = int(time.time() - init_started)
    logger.info(f'Model init done in {init_seconds}s!')

    # ---- processing run ----
    run_started = time.time()
    pdf_processor = PDFProcessor()
    all_pdfs = pdf_processor.check_pdf(args.pdf)

    for _pdf_index, single_pdf, img_list in pdf_processor.process_all_pdfs(all_pdfs):
        # Each stage receives the page images plus the result of the previous
        # stage and returns the updated per-page layout result.
        layout = analyzer.detect_layout(img_list)
        layout = formulas.detect_recognize_formulas(img_list, layout, args.batch_size)
        layout = ocr_processor.recognize_ocr(img_list, layout)
        layout = table_processor.recognize_tables(img_list, layout)

        basename = save_file(args.output, single_pdf, layout)
        logger.debug(f'Save file: {basename}.json')

        if not args.vis:
            logger.info("Visualization mode disabled")
        else:
            logger.info("Visualization mode enabled")
            get_visualize(img_list, layout, args.render, args.output, basename)

    run_seconds = int(time.time() - run_started)
    logger.info(f'Finished! time cost: {run_seconds} s')
    logger.info('----------------------------------------')
Empty file added app_tools/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions app_tools/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import yaml
import pytz
import logging
import logging.config
from datetime import datetime

# Directory that contains this module.
current_dir = os.path.split(os.path.abspath(__file__))[0]
# One level above the module's directory.
parent_dir = os.path.split(current_dir)[0]
# Location of the model configuration YAML, resolved relative to parent_dir.
model_configs_path = os.path.join(parent_dir, 'configs/model_configs.yaml')

################### MODEL CONFIGS ###################

def load_config(config_path: str = None):
    """
    Load the model configuration from a YAML file.

    Parameters:
    - config_path (str): Optional path to a YAML configuration file. When it is
      omitted (or empty), the module-level ``model_configs_path`` is used, so
      existing ``load_config()`` calls behave exactly as before.

    Returns:
    - The parsed YAML document as returned by ``yaml.load``.

    Raises:
    - OSError: If the configuration file cannot be opened.
    - yaml.YAMLError: If the file is not valid YAML.
    """
    path = config_path if config_path else model_configs_path
    # Explicit encoding so parsing does not depend on the platform's default locale.
    with open(path, encoding='utf-8') as f:
        # NOTE(review): FullLoader is kept to preserve the accepted YAML tag set;
        # only point this at trusted, project-owned configuration files.
        model_configs = yaml.load(f, Loader=yaml.FullLoader)
    return model_configs

################### LOGGING CONFIG ###################

# TODO: Define a suitable log_file_path
# File that the 'file' handler in LOGGING_CONFIG writes to; lives directly under parent_dir.
log_file_path = os.path.join(parent_dir, 'app_logs.log')

# TODO: Add to config file
# Time zone name passed to pytz; the resulting tzinfo is handed to the log
# formatter so that record timestamps are rendered in this zone.
TIMEZONE: str = "Asia/Shanghai" # 'Europe/Madrid'
timezone = pytz.timezone(TIMEZONE) # Specify your time zone here


class CustomFormatter(logging.Formatter):
    """Log formatter that renders record timestamps in a chosen time zone."""

    def __init__(self, fmt=None, datefmt=None, tz=None):
        """
        Parameters:
        - fmt: Log message format string, forwarded to ``logging.Formatter``.
        - datefmt: ``strftime`` pattern for timestamps, forwarded to ``logging.Formatter``.
        - tz: ``tzinfo`` used when converting the record creation time; it is
          passed straight to ``datetime.fromtimestamp``.
        """
        super().__init__(fmt=fmt, datefmt=datefmt)
        self.tz = tz

    def formatTime(self, record, datefmt=None):
        """
        Return the creation time of *record* as text.

        With *datefmt* the timestamp is formatted through ``strftime``;
        otherwise ISO 8601 with millisecond precision is used, falling back to
        plain ``isoformat()`` when the ``timespec`` argument is not accepted.
        """
        moment = datetime.fromtimestamp(record.created, self.tz)
        if datefmt:
            return moment.strftime(datefmt)
        try:
            return moment.isoformat(timespec='milliseconds')
        except TypeError:
            return moment.isoformat()

# Basic logging configuration, in the dictionary schema consumed by
# logging.config.dictConfig (see setup_logging below).
LOGGING_CONFIG = {
    'version': 1,
    # Loggers created before dictConfig runs stay enabled.
    'disable_existing_loggers': False,
    'formatters': {
        'standard': {
            # '()' makes dictConfig instantiate CustomFormatter; the remaining
            # keys of this entry are passed to it as keyword arguments.
            '()': CustomFormatter,
            'format': '%(asctime)s - %(name)s - [%(levelname)s] - %(message)s',
            'datefmt': '%Y-%m-%d %H:%M:%S',
            'tz': timezone
        },
    },
    'handlers': {
        # Console output: everything from DEBUG upwards.
        'console': {
            'level': 'DEBUG',
            'class': 'logging.StreamHandler',
            'formatter': 'standard'
        },
        # File output: INFO and above, written to log_file_path.
        'file': {
            'level': 'INFO',
            'class': 'logging.FileHandler',
            'filename': log_file_path,
            'formatter': 'standard'
        },
    },
    'loggers': {
        '': {  # Root logger: feeds both the console and the file handler
            'handlers': ['console', 'file'],
            'level': 'DEBUG',
            'propagate': True
        },
        # NOTE(review): this key is the literal string '__name__', so it
        # configures a logger that is literally called "__name__" rather than
        # this module's logger -- confirm whether a real logger name was intended.
        '__name__': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': False
        }
    }
}

def setup_logging(name: str = '__main__'):
    """
    Apply LOGGING_CONFIG and return the logger registered under *name*.

    Parameters:
    - name (str): Name of the logger to return. Defaults to '__main__'.

    Returns:
    - The ``logging.Logger`` obtained from ``logging.getLogger(name)``.

    Note:
    - The dictionary configuration is (re)applied on every call.
    """
    logging.config.dictConfig(LOGGING_CONFIG)
    return logging.getLogger(name)
237 changes: 237 additions & 0 deletions app_tools/formula_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import os, gc
import time
import argparse
from typing import List, Tuple

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from ultralytics import YOLO
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor ## TODO: WARNING 'load_processor' is not declared in __all__
from modules.post_process import get_croped_image, latex_rm_whitespace

from app_tools.config import load_config, setup_logging

logger = setup_logging('formula_analysis')

class MathDataset(Dataset):
    """
    Dataset over a sequence of formula images.

    Each entry is either a string, which is treated as a file path and opened
    with PIL on access, or any other object, which is returned as-is. An
    optional transform is applied to the item before it is returned.
    """

    def __init__(self, image_paths, transform=None):
        """
        Parameters:
        - image_paths: Sequence of file paths (str) and/or already loaded images.
        - transform: Optional callable applied to every item on access.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of entries."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Return the (optionally transformed) image at position *idx*.

        Raises:
        - IndexError: If *idx* is negative or not smaller than the dataset length.
        - IOError: If the entry is a path that cannot be opened as an image.
        """
        if not 0 <= idx < len(self.image_paths):
            raise IndexError("Index out of range")

        entry = self.image_paths[idx]
        # Strings are paths that still have to be loaded; anything else is
        # taken to be an image already.
        if isinstance(entry, str):
            try:
                entry = Image.open(entry)
            except IOError as e:
                raise IOError(f"Error opening image: {self.image_paths[idx]}") from e

        return self.transform(entry) if self.transform else entry


class FormulaProcessor:
    """
    Detect formulas on page images and recognize them as LaTeX.

    Detection is done with a YOLO model (``mfd_model``); recognition with a
    model built through the UniMERNet task/config machinery (``mfr_model``),
    fed through ``mfr_transform``. Formulas found by ``detect_formulas`` are
    queued in ``latex_filling_list`` / ``mf_image_list`` until
    ``recognize_formulas`` consumes the queue.
    """

    def __init__(self, config_path: str = None):
        """
        Load the configuration and initialize both models.

        Parameters:
        - config_path (str): Optional path to a configuration file. When omitted,
          ``load_config()`` is called without arguments.

        Attributes:
        - config: The loaded configuration.
        - mfd_model: Model used for formula detection.
        - mfr_model: Model used for formula recognition.
        - mfr_transform: Transform applied to formula crops before recognition.
        - latex_filling_list: Layout entries whose 'latex' field is still to be filled.
        - mf_image_list: Cropped formula images, index-aligned with latex_filling_list.
        """
        self.config = load_config(config_path) if config_path else load_config()
        self.mfd_model = self._init_mfd_model()
        self.mfr_model, self.mfr_transform = self._init_mfr_model()
        self.latex_filling_list = []
        self.mf_image_list = []

    def _init_mfd_model(self):
        """
        Create the formula detection model.

        Returns:
        - A ``YOLO`` instance built from ``config['model_args']['mfd_weight']``.
        """
        weight = self.config['model_args']['mfd_weight']
        return YOLO(weight)

    def _init_mfr_model(self) -> Tuple[torch.nn.Module, transforms.Compose]:
        """
        Build the formula recognition model and its input transform.

        The weight directory and device are read from ``config['model_args']``
        ('mfr_weight' and 'device'); the base model configuration comes from
        ``modules/UniMERNet/configs/demo.yaml``.

        Returns:
        - Tuple: The model moved to the configured device, and a
          ``transforms.Compose`` wrapping the 'formula_image_eval' processor.
        """
        weight_dir = self.config['model_args']['mfr_weight']
        device = self.config['model_args']['device']

        args = argparse.Namespace(cfg_path="modules/UniMERNet/configs/demo.yaml", options=None)
        cfg = Config(args)
        # Point checkpoint, model name and tokenizer at the configured weight directory.
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg)
        model = model.to(device)
        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
        mfr_transform = transforms.Compose([vis_processor])
        return model, mfr_transform

    def detect_formulas(self, img_list: List, doc_layout_result: List[dict]) -> List[dict]:
        """
        Detect formulas on every page and register them in the layout result.

        For each detected box a new entry (category id, polygon, score and an
        empty 'latex' field) is appended to the page's 'layout_dets' list. The
        same entry and the cropped formula image are queued for
        ``recognize_formulas``.

        Parameters:
        - img_list: Page images; ``img_list[i]`` belongs to ``doc_layout_result[i]``.
          Each image is passed to ``Image.fromarray``, so it is presumably a
          NumPy array -- TODO confirm against the caller.
        - doc_layout_result: One dict per page, each with a 'layout_dets' list.

        Returns:
        - The same ``doc_layout_result`` list, updated in place.
        """
        model_args = self.config['model_args']
        img_size = model_args['img_size']
        conf_thres = model_args['conf_thres']
        iou_thres = model_args['iou_thres']

        logger.debug('Formula detection - init')
        start = time.time()

        for idx, image in enumerate(img_list):
            mfd_res = self.mfd_model.predict(image, imgsz=img_size, conf=conf_thres, iou=iou_thres, verbose=True)[0]
            boxes = mfd_res.boxes

            for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    # NOTE(review): 13 looks like the offset of the formula
                    # categories in the layout label set -- confirm.
                    'category_id': 13 + int(cla.item()),
                    # Box corners in the order: top-left, top-right, bottom-right, bottom-left.
                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    'score': round(float(conf.item()), 2),
                    'latex': '',
                }
                doc_layout_result[idx]['layout_dets'].append(new_item)
                # The very same dict is queued, so filling 'latex' later also
                # updates the entry stored in doc_layout_result.
                self.latex_filling_list.append(new_item)
                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
                self.mf_image_list.append(bbox_img)

            # Release per-page detection results before moving to the next page.
            del boxes, mfd_res
            torch.cuda.empty_cache()
            gc.collect()

        logger.debug(f'Formula detection done in {round(time.time() - start, 2)}s!')

        return doc_layout_result

    def recognize_formulas(self, batch_size: int = 128):
        """
        Recognize all queued formula crops and fill in their LaTeX.

        The crops in ``mf_image_list`` are run through the recognition model in
        batches, and the cleaned prediction is written to the 'latex' field of
        the matching entry in ``latex_filling_list``. The queue is emptied
        afterwards, so formulas of one document are never processed again
        together with the next document when the instance is reused.

        Parameters:
        - batch_size: Number of crops per model call. Default is 128.

        Returns:
        None
        """
        device = self.config['model_args']['device']

        logger.debug('Formula recognition')
        start = time.time()
        formula_count = len(self.mf_image_list)

        try:
            # Nothing queued: skip building a DataLoader (and its worker
            # processes) altogether.
            if formula_count:
                dataset = MathDataset(self.mf_image_list, transform=self.mfr_transform)
                # NOTE(review): the worker count is hard-coded; consider moving it to the config.
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=32)
                mfr_res = []
                # Pure inference: no autograd bookkeeping needed.
                with torch.no_grad():
                    for imgs in dataloader:
                        imgs = imgs.to(device)
                        output = self.mfr_model.generate({'image': imgs})
                        mfr_res.extend(output['pred_str'])
                for res, latex in zip(self.latex_filling_list, mfr_res):
                    res['latex'] = latex_rm_whitespace(latex)
        finally:
            # Start the next document with an empty queue, even if recognition
            # failed. New lists are bound instead of clearing in place so that
            # any outside reference to the old lists is left untouched.
            self.latex_filling_list = []
            self.mf_image_list = []

        logger.info(f'Formula nums: {formula_count} mfr time: {round(time.time() - start, 2)}')

    def detect_recognize_formulas(self, img_list: List, doc_layout_result: List[dict], batch_size: int = 128):
        """
        Detect formulas on all pages, then recognize them.

        Calls ``detect_formulas`` followed by ``recognize_formulas``.

        Parameters:
        - `img_list` (List): A list of page images.
        - `doc_layout_result` (List[dict]): One layout dict per page.
        - `batch_size` (int, optional): The batch size for recognition. Defaults to 128.

        Returns:
        - `doc_layout_result` (List[dict]): The updated document layout results.
        """
        doc_layout_result = self.detect_formulas(img_list, doc_layout_result)
        self.recognize_formulas(batch_size)
        return doc_layout_result

    def clear_memory(self):
        """
        Drop both models and ask PyTorch and the garbage collector to release memory.

        After this call ``mfd_model`` and ``mfr_model`` are ``None``.
        """
        logger.info('Clearing models from memory.')
        # Rebinding to None drops this object's references to the models.
        self.mfd_model = None
        self.mfr_model = None

        torch.cuda.empty_cache()
        gc.collect()
        logger.info('Models successfully cleared from memory.')
Loading