Skip to content

Commit 9b73ff1

Browse files
authored
[feat] Dynamic import based on the available dependencies (#65)
* Introducing file_utils and improve init * Remove the import class within OCR modules * fix paddle check in file utils * Improve layoutmodel specs * More robust import for ocr agents
1 parent e8d5488 commit 9b73ff1

File tree

9 files changed

+309
-125
lines changed

9 files changed

+309
-125
lines changed

src/layoutparser/__init__.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,63 @@
11
__version__ = "0.2.0"
22

3-
from .elements import (
4-
Interval, Rectangle, Quadrilateral,
5-
TextBlock, Layout
6-
)
3+
import sys
74

8-
from .visualization import (
9-
draw_box, draw_text
5+
from .file_utils import (
6+
_LazyModule,
7+
is_detectron2_available,
8+
is_paddle_available,
9+
is_pytesseract_available,
10+
is_gcv_available,
1011
)
1112

12-
from .ocr import (
13-
GCVFeatureType, GCVAgent,
14-
TesseractFeatureType, TesseractAgent
15-
)
13+
_import_structure = {
14+
"elements": [
15+
"Interval",
16+
"Rectangle",
17+
"Quadrilateral",
18+
"TextBlock",
19+
"Layout"
20+
],
21+
"visualization": [
22+
"draw_box",
23+
"draw_text"
24+
],
25+
"io": [
26+
"load_json",
27+
"load_dict",
28+
"load_csv",
29+
"load_dataframe"
30+
],
31+
"file_utils":[
32+
"is_torch_available",
33+
"is_torch_cuda_available",
34+
"is_detectron2_available",
35+
"is_paddle_available",
36+
"is_pytesseract_available",
37+
"is_gcv_available",
38+
"requires_backends"
39+
]
40+
}
1641

17-
from .models import (
18-
Detectron2LayoutModel,
19-
PaddleDetectionLayoutModel
20-
)
42+
if is_detectron2_available():
43+
_import_structure["models.detectron2"] = ["Detectron2LayoutModel"]
44+
45+
if is_paddle_available():
46+
_import_structure["models.paddledetection"] = ["PaddleDetectionLayoutModel"]
2147

22-
from .io import (
23-
load_json,
24-
load_dict,
25-
load_csv,
26-
load_dataframe
27-
)
48+
if is_pytesseract_available():
49+
_import_structure["ocr.tesseract_agent"] = [
50+
"TesseractAgent",
51+
"TesseractFeatureType",
52+
]
53+
54+
if is_gcv_available():
55+
_import_structure["ocr.gcv_agent"] = ["GCVAgent", "GCVFeatureType"]
56+
57+
sys.modules[__name__] = _LazyModule(
58+
__name__,
59+
globals()["__file__"],
60+
_import_structure,
61+
module_spec=__spec__,
62+
extra_objects={"__version__": __version__},
63+
)

src/layoutparser/file_utils.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Some of the code in this file is adapted from
2+
# https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
3+
4+
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
5+
import sys
6+
import os
7+
import logging
8+
import importlib.util
9+
from types import ModuleType
10+
11+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
12+
13+
# The package importlib_metadata is in a different place, depending on the python version.
14+
if sys.version_info < (3, 8):
15+
import importlib_metadata
16+
else:
17+
import importlib.metadata as importlib_metadata
18+
19+
###########################################
20+
############ Layout Model Deps ############
21+
###########################################
22+
23+
_torch_available = importlib.util.find_spec("torch") is not None
24+
try:
25+
_torch_version = importlib_metadata.version("torch")
26+
logger.debug(f"PyTorch version {_torch_version} available.")
27+
except importlib_metadata.PackageNotFoundError:
28+
_torch_available = False
29+
30+
_detectron2_available = importlib.util.find_spec("detectron2") is not None
31+
try:
32+
_detectron2_version = importlib_metadata.version("detectron2")
33+
logger.debug(f"Detectron2 version {_detectron2_version} available")
34+
except importlib_metadata.PackageNotFoundError:
35+
_detectron2_available = False
36+
37+
_paddle_available = importlib.util.find_spec("paddle") is not None
38+
try:
39+
# The name of the paddlepaddle library:
40+
# Install name: pip install paddlepaddle
41+
# Import name: import paddle
42+
_paddle_version = importlib_metadata.version("paddlepaddle")
43+
logger.debug(f"Paddle version {_paddle_version} available.")
44+
except importlib_metadata.PackageNotFoundError:
45+
_paddle_available = False
46+
47+
###########################################
48+
############## OCR Tool Deps ##############
49+
###########################################
50+
51+
_pytesseract_available = importlib.util.find_spec("pytesseract") is not None
52+
try:
53+
_pytesseract_version = importlib_metadata.version("pytesseract")
54+
logger.debug(f"Pytesseract version {_pytesseract_version} available.")
55+
except importlib_metadata.PackageNotFoundError:
56+
_pytesseract_available = False
57+
58+
_gcv_available = importlib.util.find_spec("google.cloud.vision") is not None
59+
try:
60+
_gcv_version = importlib_metadata.version(
61+
"google-cloud-vision"
62+
) # This is slightly different
63+
logger.debug(f"Google Cloud Vision Utils version {_gcv_version} available.")
64+
except importlib_metadata.PackageNotFoundError:
65+
_gcv_available = False
66+
67+
68+
def is_torch_available() -> bool:
    """Return the module-level ``_torch_available`` flag."""
    return _torch_available
70+
71+
72+
def is_torch_cuda_available():
    """Return ``torch.cuda.is_available()``, or ``False`` when torch is unavailable.

    ``torch`` is imported inside the function so that this module can be
    imported without torch installed.
    """
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available()
79+
80+
81+
def is_paddle_available() -> bool:
    """Return the module-level ``_paddle_available`` flag."""
    return _paddle_available
83+
84+
85+
def is_detectron2_available() -> bool:
    """Return the module-level ``_detectron2_available`` flag."""
    return _detectron2_available
87+
88+
89+
def is_pytesseract_available() -> bool:
    """Return the module-level ``_pytesseract_available`` flag."""
    return _pytesseract_available
91+
92+
93+
def is_gcv_available() -> bool:
    """Return the module-level ``_gcv_available`` flag."""
    return _gcv_available
95+
96+
97+
# Error-message templates used by ``requires_backends``. Each template takes a
# single positional ``str.format`` argument: the name of the object that needs
# the backend.

PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# NOTE(review): the pinned tag in the pip command below was reconstructed as
# ``v0.4`` after the original text had been mangled — confirm the intended tag.
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment. Typically the following would work for MacOS or Linux CPU machines:
pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.4#egg=detectron2'
"""

PADDLE_IMPORT_ERROR = """
{0} requires the PaddlePaddle library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/PaddlePaddle/Paddle and follow the ones that match your environment.
"""

PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`
"""

GCV_IMPORT_ERROR = """
{0} requires the Google Cloud Vision Python utils but it was not found in your environment. You can install it with pip:
`pip install google-cloud-vision==1`
"""
123+
124+
# Maps a backend key to ``(availability_check, error_message_template)``.
# ``requires_backends`` looks backends up here by key.
BACKENDS_MAPPING = {
    "torch": (is_torch_available, PYTORCH_IMPORT_ERROR),
    "detectron2": (is_detectron2_available, DETECTRON2_IMPORT_ERROR),
    "paddle": (is_paddle_available, PADDLE_IMPORT_ERROR),
    "pytesseract": (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR),
    "google-cloud-vision": (is_gcv_available, GCV_IMPORT_ERROR),
}
133+
134+
135+
def requires_backends(obj, backends):
    """Raise ``ImportError`` if any of the requested backends is unavailable.

    Args:
        obj: The object (class, function or instance) that needs the
            backends. Its ``__name__`` (or, failing that, its class name) is
            inserted into the error message.
        backends: A single backend key or a list/tuple of keys; each must be
            a key of ``BACKENDS_MAPPING``.

    Raises:
        ImportError: If at least one backend reports itself unavailable. The
            message contains the installation hint of every *missing*
            backend only; backends that are installed are not mentioned.
        KeyError: If a backend key is not present in ``BACKENDS_MAPPING``.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__

    # Collect messages for the unavailable backends only, so the user is not
    # told to install libraries that are already present.
    missing_messages = [
        message.format(name)
        for is_available, message in (BACKENDS_MAPPING[backend] for backend in backends)
        if not is_available()
    ]
    if missing_messages:
        raise ImportError("".join(missing_messages))
144+
145+
146+
class _LazyModule(ModuleType):
147+
"""
148+
Module class that surfaces all objects but only performs associated imports when the objects are requested.
149+
"""
150+
151+
# Adapted from HuggingFace
152+
# https://github.com/huggingface/transformers/blob/c37573806ab3526dd805c49cbe2489ad4d68a9d7/src/transformers/file_utils.py#L1990
153+
154+
def __init__(
155+
self, name, module_file, import_structure, module_spec=None, extra_objects=None
156+
):
157+
super().__init__(name)
158+
self._modules = set(import_structure.keys())
159+
self._class_to_module = {}
160+
for key, values in import_structure.items():
161+
for value in values:
162+
self._class_to_module[value] = key
163+
# Needed for autocompletion in an IDE
164+
self.__all__ = list(import_structure.keys()) + sum(
165+
import_structure.values(), []
166+
)
167+
self.__file__ = module_file
168+
self.__spec__ = module_spec
169+
self.__path__ = [os.path.dirname(module_file)]
170+
self._objects = {} if extra_objects is None else extra_objects
171+
self._name = name
172+
self._import_structure = import_structure
173+
174+
# Following [PEP 366](https://www.python.org/dev/peps/pep-0366/)
175+
# The __package__ variable should be set
176+
# https://docs.python.org/3/reference/import.html#__package__
177+
self.__package__ = self.__name__
178+
179+
# Needed for autocompletion in an IDE
180+
def __dir__(self):
181+
return super().__dir__() + self.__all__
182+
183+
def __getattr__(self, name: str) -> Any:
184+
if name in self._objects:
185+
return self._objects[name]
186+
if name in self._modules:
187+
value = self._get_module(name)
188+
elif name in self._class_to_module.keys():
189+
module = self._get_module(self._class_to_module[name])
190+
value = getattr(module, name)
191+
else:
192+
raise AttributeError(f"module {self.__name__} has no attribute {name}")
193+
194+
setattr(self, name, value)
195+
return value
196+
197+
def _get_module(self, module_name: str):
198+
return importlib.import_module("." + module_name, self.__name__)
199+
200+
def __reduce__(self):
201+
return (self.__class__, (self._name, self.__file__, self._import_structure))
Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
2-
import os
3-
import importlib
2+
3+
from ..file_utils import requires_backends
44

55

66
class BaseLayoutModel(ABC):
@@ -23,28 +23,7 @@ def DEPENDENCIES(self):
2323
"""DEPENDENCIES lists all necessary dependencies for the class."""
2424
pass
2525

26-
@property
27-
@abstractmethod
28-
def MODULES(self):
29-
"""MODULES instructs how to import these necessary libraries."""
30-
pass
31-
32-
@classmethod
33-
def _import_module(cls):
34-
for m in cls.MODULES:
35-
if importlib.util.find_spec(m["module_path"]):
36-
setattr(
37-
cls, m["import_name"], importlib.import_module(m["module_path"])
38-
)
39-
else:
40-
raise ModuleNotFoundError(
41-
f"\n "
42-
f"\nPlease install the following libraries to support the class {cls.__name__}:"
43-
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
44-
f"\n "
45-
)
46-
4726
def __new__(cls, *args, **kwargs):
4827

49-
cls._import_module()
28+
requires_backends(cls, cls.DEPENDENCIES)
5029
return super().__new__(cls)

src/layoutparser/models/detectron2/layoutmodel.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from PIL import Image
22
import numpy as np
3-
import torch
43

54
from .catalog import PathManager, LABEL_MAP_CATALOG
65
from ..base_layoutmodel import BaseLayoutModel
76
from ...elements import Rectangle, TextBlock, Layout
7+
from ...file_utils import is_torch_cuda_available, is_detectron2_available
8+
9+
if is_detectron2_available():
10+
import detectron2.engine
11+
import detectron2.config
12+
813

914
__all__ = ["Detectron2LayoutModel"]
1015

@@ -42,13 +47,6 @@ class Detectron2LayoutModel(BaseLayoutModel):
4247
"""
4348

4449
DEPENDENCIES = ["detectron2"]
45-
MODULES = [
46-
{
47-
"import_name": "_engine",
48-
"module_path": "detectron2.engine",
49-
},
50-
{"import_name": "_config", "module_path": "detectron2.config"},
51-
]
5250
DETECTOR_NAME = "detectron2"
5351

5452
def __init__(
@@ -70,7 +68,7 @@ def __init__(
7068
if enforce_cpu:
7169
extra_config.extend(["MODEL.DEVICE", "cpu"])
7270

73-
cfg = self._config.get_cfg()
71+
cfg = detectron2.config.get_cfg()
7472
config_path = self._reconstruct_path_with_detector_name(config_path)
7573
config_path = PathManager.get_local_path(config_path)
7674
cfg.merge_from_file(config_path)
@@ -79,7 +77,10 @@ def __init__(
7977
if model_path is not None:
8078
model_path = self._reconstruct_path_with_detector_name(model_path)
8179
cfg.MODEL.WEIGHTS = model_path
82-
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
80+
81+
if not enforce_cpu:
82+
cfg.MODEL.DEVICE = "cuda" if is_torch_cuda_available() else "cpu"
83+
8384
self.cfg = cfg
8485

8586
self.label_map = label_map
@@ -135,7 +136,7 @@ def gather_output(self, outputs):
135136
return layout
136137

137138
def _create_model(self):
138-
self.model = self._engine.DefaultPredictor(self.cfg)
139+
self.model = detectron2.engine.DefaultPredictor(self.cfg)
139140

140141
def detect(self, image):
141142
"""Detect the layout of a given image.

0 commit comments

Comments
 (0)