Skip to content

Commit 60807a7

Browse files
feat: New document figure classifier model (#73)
Signed-off-by: Matteo Omenetti <[email protected]> Signed-off-by: Nikos Livathinos <[email protected]> Co-authored-by: Nikos Livathinos <[email protected]>
1 parent 71e4c2f commit 60807a7

File tree

5 files changed

+386
-0
lines changed

5 files changed

+386
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
import argparse
6+
import logging
7+
import os
8+
import sys
9+
import time
10+
from pathlib import Path
11+
12+
from huggingface_hub import snapshot_download
13+
from PIL import Image
14+
15+
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import DocumentFigureClassifierPredictor
16+
17+
18+
def demo(
19+
logger: logging.Logger,
20+
artifact_path: str,
21+
device: str,
22+
num_threads: int,
23+
image_dir: str,
24+
viz_dir: str,
25+
):
26+
r"""
27+
Apply DocumentFigureClassifierPredictor on the input image directory
28+
"""
29+
# Create the layout predictor
30+
document_figure_classifier_predictor = DocumentFigureClassifierPredictor(artifact_path, device=device, num_threads=num_threads)
31+
32+
image_dir = Path(image_dir)
33+
images = []
34+
image_names = os.listdir(image_dir)
35+
image_names.sort()
36+
for image_name in image_names:
37+
image = Image.open(image_dir / image_name)
38+
images.append(image)
39+
40+
t0 = time.perf_counter()
41+
outputs = document_figure_classifier_predictor.predict(images)
42+
total_ms = 1000 * (time.perf_counter() - t0)
43+
avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0
44+
logger.info(
45+
"For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
46+
len(image_names), total_ms, avg_ms
47+
)
48+
)
49+
50+
for i, output in enumerate(outputs):
51+
image_name = image_names[i]
52+
logger.info(f"Predictions for: '{image_name}':")
53+
for pred in output:
54+
logger.info(f" Class '{pred[0]}' has probability {pred[1]}")
55+
56+
57+
def main(args):
58+
num_threads = int(args.num_threads) if args.num_threads is not None else None
59+
device = args.device.lower()
60+
image_dir = args.image_dir
61+
viz_dir = args.viz_dir
62+
63+
# Initialize logger
64+
logging.basicConfig(level=logging.DEBUG)
65+
logger = logging.getLogger("DocumentFigureClassifierPredictor")
66+
logger.setLevel(logging.DEBUG)
67+
if not logger.hasHandlers():
68+
handler = logging.StreamHandler(sys.stdout)
69+
formatter = logging.Formatter(
70+
"%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
71+
)
72+
handler.setFormatter(formatter)
73+
logger.addHandler(handler)
74+
75+
# Ensure the viz dir
76+
Path(viz_dir).mkdir(parents=True, exist_ok=True)
77+
78+
# Download models from HF
79+
download_path = snapshot_download(repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.0")
80+
81+
# Test the figure classifier model
82+
demo(logger, download_path, device, num_threads, image_dir, viz_dir)
83+
84+
85+
if __name__ == "__main__":
86+
r"""
87+
python -m demo.demo_document_figure_classifier_predictor -i <images_dir>
88+
"""
89+
parser = argparse.ArgumentParser(description="Test the DocumentFigureClassifierPredictor")
90+
parser.add_argument(
91+
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
92+
)
93+
parser.add_argument(
94+
"-n", "--num_threads", required=False, default=4, help="Number of threads"
95+
)
96+
parser.add_argument(
97+
"-i",
98+
"--image_dir",
99+
required=True,
100+
help="PNG images input directory",
101+
)
102+
parser.add_argument(
103+
"-v",
104+
"--viz_dir",
105+
required=False,
106+
default="viz/",
107+
help="Directory to save prediction visualizations",
108+
)
109+
110+
args = parser.parse_args()
111+
main(args)
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
import logging
6+
from typing import List, Tuple, Union
7+
8+
import numpy as np
9+
import torch
10+
import torchvision.transforms as transforms
11+
from PIL import Image
12+
from transformers import AutoConfig, AutoModelForImageClassification
13+
14+
_log = logging.getLogger(__name__)
15+
16+
17+
class DocumentFigureClassifierPredictor:
18+
r"""
19+
Model for classifying document figures.
20+
21+
Classifies figures as 1 out of 16 possible classes.
22+
23+
The classes are:
24+
1. "bar_chart"
25+
2. "bar_code"
26+
3. "chemistry_markush_structure"
27+
4. "chemistry_molecular_structure"
28+
5. "flow_chart"
29+
6. "icon"
30+
7. "line_chart"
31+
8. "logo"
32+
9. "map"
33+
10. "other"
34+
11. "pie_chart"
35+
12. "qr_code"
36+
13. "remote_sensing"
37+
14. "screenshot"
38+
15. "signature"
39+
16. "stamp"
40+
41+
Attributes
42+
----------
43+
_device : str
44+
The device on which the model is loaded (e.g., 'cpu' or 'cuda').
45+
_num_threads : int
46+
Number of threads used for inference when running on CPU.
47+
_model : EfficientNetForImageClassification
48+
Pretrained EfficientNetb0 model.
49+
_image_processor : EfficientNetImageProcessor
50+
Processor for normalizing and preparing input images.
51+
_classes: List[str]:
52+
The classes used by the model.
53+
54+
Methods
55+
-------
56+
__init__(artifacts_path, device, num_threads)
57+
Initializes the DocumentFigureClassifierPredictor with the specified parameters.
58+
info() -> dict:
59+
Retrieves configuration details of the DocumentFigureClassifierPredictor instance.
60+
predict(images) -> List[List[float]]
61+
The confidence scores for the classification of each image.
62+
"""
63+
64+
def __init__(
65+
self,
66+
artifacts_path: str,
67+
device: str = "cpu",
68+
num_threads: int = 4,
69+
):
70+
r"""
71+
Initializes the DocumentFigureClassifierPredictor.
72+
73+
Parameters
74+
----------
75+
artifacts_path : str
76+
Path to the directory containing the pretrained model files.
77+
device : str, optional
78+
Device to run the inference on ('cpu' or 'cuda'), by default "cpu".
79+
num_threads : int, optional
80+
Number of threads for CPU inference, by default 4.
81+
"""
82+
self._device = device
83+
self._num_threads = num_threads
84+
85+
if device == "cpu":
86+
torch.set_num_threads(self._num_threads)
87+
88+
model = AutoModelForImageClassification.from_pretrained(artifacts_path)
89+
self._model = model.to(device)
90+
self._model.eval()
91+
92+
self._image_processor = transforms.Compose(
93+
[
94+
transforms.Resize((224, 224)),
95+
transforms.ToTensor(),
96+
transforms.Normalize(
97+
mean=[0.485, 0.456, 0.406],
98+
std=[0.47853944, 0.4732864, 0.47434163],
99+
),
100+
]
101+
)
102+
103+
config = AutoConfig.from_pretrained(artifacts_path)
104+
105+
self._classes = list(config.id2label.values())
106+
self._classes.sort()
107+
108+
_log.debug("CodeFormulaModel settings: {}".format(self.info()))
109+
110+
def info(self) -> dict:
111+
"""
112+
Retrieves configuration details of the DocumentFigureClassifierPredictor instance.
113+
114+
Returns
115+
-------
116+
dict
117+
A dictionary containing configuration details such as the device,
118+
the number of threads used and the classe sused by the model.
119+
"""
120+
info = {
121+
"device": self._device,
122+
"num_threads": self._num_threads,
123+
"classes": self._classes,
124+
}
125+
return info
126+
127+
def predict(
128+
self, images: List[Union[Image.Image, np.ndarray]]
129+
) -> List[List[Tuple[str, float]]]:
130+
r"""
131+
Performs inference on a batch of figures.
132+
133+
Parameters
134+
----------
135+
images : List[Union[Image.Image, np.ndarray]]
136+
A list of input images for inference. Each image can either be a
137+
PIL.Image.Image object or a NumPy array representing an image.
138+
139+
Returns
140+
-------
141+
List[List[Tuple[str, float]]]
142+
A list of predictions for each input image. Each prediction is a list of
143+
tuples representing the predicted class and confidence score:
144+
- str: The predicted class name for the image.
145+
- float: The confidence score associated with the predicted class,
146+
ranging from 0 to 1.
147+
148+
The predictions for each image are sorted in descending order of confidence.
149+
"""
150+
processed_images = []
151+
for image in images:
152+
if isinstance(image, Image.Image):
153+
processed_images.append(image.convert("RGB"))
154+
elif isinstance(image, np.ndarray):
155+
processed_images.append(Image.fromarray(image).convert("RGB"))
156+
else:
157+
raise TypeError(
158+
"Supported input formats are PIL.Image.Image or numpy.ndarray."
159+
)
160+
images = processed_images
161+
162+
# (batch_size, 3, 224, 224)
163+
images = [self._image_processor(image) for image in images]
164+
images = torch.stack(images).to(self._device)
165+
166+
with torch.no_grad():
167+
logits = self._model(images).logits # (batch_size, num_classes)
168+
probs_batch = logits.softmax(dim=1) # (batch_size, num_classes)
169+
probs_batch = probs_batch.cpu().numpy().tolist()
170+
171+
predictions_batch = []
172+
for probs_image in probs_batch:
173+
preds = [(self._classes[i], prob) for i, prob in enumerate(probs_image)]
174+
preds.sort(key=lambda t: t[1], reverse=True)
175+
predictions_batch.append(preds)
176+
177+
return predictions_batch
47.5 KB
Loading
130 KB
Loading
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
import os
6+
import numpy as np
7+
import pytest
8+
from PIL import Image
9+
10+
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
11+
DocumentFigureClassifierPredictor,
12+
)
13+
14+
from huggingface_hub import snapshot_download
15+
16+
17+
@pytest.fixture(scope="module")
18+
def init() -> dict:
19+
r"""
20+
Initialize the testing environment
21+
"""
22+
init = {
23+
"num_threads": 1,
24+
"test_imgs": [
25+
{
26+
"label": "bar_chart",
27+
"image_path": "tests/test_data/figure_classifier/images/bar_chart.jpg",
28+
},
29+
{
30+
"label": "map",
31+
"image_path": "tests/test_data/figure_classifier/images/map.jpg",
32+
},
33+
],
34+
"info": {
35+
"device": "auto",
36+
},
37+
}
38+
39+
# Download models from HF
40+
init["artifact_path"] = snapshot_download(
41+
repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.0"
42+
)
43+
44+
return init
45+
46+
47+
def test_figure_classifier(init: dict):
48+
r"""
49+
Unit test for the CodeFormulaPredictor
50+
"""
51+
device = "cpu"
52+
num_threads = 2
53+
54+
# Initialize LayoutPredictor
55+
figure_classifier = DocumentFigureClassifierPredictor(
56+
init["artifact_path"], device=device, num_threads=num_threads
57+
)
58+
59+
# Check info
60+
info = figure_classifier.info()
61+
assert info["device"] == device, "Wronly set device"
62+
assert info["num_threads"] == num_threads, "Wronly set number of threads"
63+
64+
# Unsupported input image
65+
is_exception = False
66+
try:
67+
for _ in figure_classifier.predict(["wrong"]):
68+
pass
69+
except TypeError:
70+
is_exception = True
71+
assert is_exception
72+
73+
# Predict on test images, not batched
74+
for d in init["test_imgs"]:
75+
label = d["label"]
76+
img_path = d["image_path"]
77+
78+
with Image.open(img_path) as img:
79+
80+
output = figure_classifier.predict([img])
81+
predicted_class = output[0][0][0]
82+
83+
assert predicted_class == label
84+
85+
# Load images as numpy arrays
86+
np_arr = np.asarray(img)
87+
output = figure_classifier.predict([np_arr])
88+
predicted_class = output[0][0][0]
89+
90+
assert predicted_class == label
91+
92+
# Predict on test images, batched
93+
labels = [d['label'] for d in init["test_imgs"]]
94+
images = [Image.open(d["image_path"]) for d in init["test_imgs"]]
95+
96+
outputs = figure_classifier.predict(images)
97+
outputs = [output[0][0] for output in outputs]
98+
assert outputs == labels

0 commit comments

Comments
 (0)