Skip to content

Commit a258d52

Browse files
authored
feat: add visualizers (#263)
* feat: add visualizers Signed-off-by: Panos Vagenas <[email protected]> * make visualizers composable Signed-off-by: Panos Vagenas <[email protected]> * use BoundingRectangle instead of BoundingBox Signed-off-by: Panos Vagenas <[email protected]> * enforce top-left coordinates Signed-off-by: Panos Vagenas <[email protected]> * narrow down test data to first 3 pages Signed-off-by: Panos Vagenas <[email protected]> * add file deletions Signed-off-by: Panos Vagenas <[email protected]> --------- Signed-off-by: Panos Vagenas <[email protected]>
1 parent 8b676b9 commit a258d52

12 files changed

+438
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Define the visualizer types."""
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Define base classes for visualization."""
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Optional
5+
6+
from PIL.Image import Image
7+
from pydantic import BaseModel
8+
9+
from docling_core.types.doc import DoclingDocument
10+
11+
12+
class BaseVisualizer(BaseModel, ABC):
13+
"""Visualize base class."""
14+
15+
@abstractmethod
16+
def get_visualization(
17+
self,
18+
*,
19+
doc: DoclingDocument,
20+
**kwargs,
21+
) -> dict[Optional[int], Image]:
22+
"""Get visualization of the document as images by page."""
23+
raise NotImplementedError()
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Define classes for layout visualization."""
2+
3+
from copy import deepcopy
4+
from typing import Literal, Optional, Union
5+
6+
from PIL import ImageDraw, ImageFont
7+
from PIL.Image import Image
8+
from PIL.ImageFont import FreeTypeFont
9+
from pydantic import BaseModel
10+
from typing_extensions import override
11+
12+
from docling_core.transforms.visualizer.base import BaseVisualizer
13+
from docling_core.types.doc import DocItemLabel
14+
from docling_core.types.doc.base import CoordOrigin
15+
from docling_core.types.doc.document import ContentLayer, DocItem, DoclingDocument
16+
from docling_core.types.doc.page import BoundingRectangle, TextCell
17+
18+
19+
class _TLBoundingRectangle(BoundingRectangle):
20+
coord_origin: Literal[CoordOrigin.TOPLEFT] = CoordOrigin.TOPLEFT
21+
22+
23+
class _TLTextCell(TextCell):
24+
rect: _TLBoundingRectangle
25+
26+
27+
class _TLCluster(BaseModel):
28+
id: int
29+
label: DocItemLabel
30+
brec: _TLBoundingRectangle
31+
confidence: float = 1.0
32+
cells: list[_TLTextCell] = []
33+
children: list["_TLCluster"] = [] # Add child cluster support
34+
35+
36+
class LayoutVisualizer(BaseVisualizer):
37+
"""Layout visualizer."""
38+
39+
class Params(BaseModel):
40+
"""Layout visualization parameters."""
41+
42+
show_label: bool = True
43+
44+
base_visualizer: Optional[BaseVisualizer] = None
45+
params: Params = Params()
46+
47+
def _draw_clusters(
48+
self, image: Image, clusters: list[_TLCluster], scale_x: float, scale_y: float
49+
) -> None:
50+
"""Draw clusters on an image."""
51+
draw = ImageDraw.Draw(image, "RGBA")
52+
# Create a smaller font for the labels
53+
font: Union[ImageFont.ImageFont, FreeTypeFont]
54+
try:
55+
font = ImageFont.truetype("arial.ttf", 12)
56+
except OSError:
57+
# Fallback to default font if arial is not available
58+
font = ImageFont.load_default()
59+
for c_tl in clusters:
60+
all_clusters = [c_tl, *c_tl.children]
61+
for c in all_clusters:
62+
# Draw cells first (underneath)
63+
cell_color = (0, 0, 0, 40) # Transparent black for cells
64+
for tc in c.cells:
65+
cx0, cy0, cx1, cy1 = tc.rect.to_bounding_box().as_tuple()
66+
cx0 *= scale_x
67+
cx1 *= scale_x
68+
cy0 *= scale_y
69+
cy1 *= scale_y
70+
71+
draw.rectangle(
72+
[(cx0, cy0), (cx1, cy1)],
73+
outline=None,
74+
fill=cell_color,
75+
)
76+
# Draw cluster rectangle
77+
x0, y0, x1, y1 = c.brec.to_bounding_box().as_tuple()
78+
x0 *= scale_x
79+
x1 *= scale_x
80+
y0 *= scale_y
81+
y1 *= scale_y
82+
83+
cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70)
84+
cluster_outline_color = (
85+
*list(DocItemLabel.get_color(c.label)),
86+
255,
87+
)
88+
draw.rectangle(
89+
[(x0, y0), (x1, y1)],
90+
outline=cluster_outline_color,
91+
fill=cluster_fill_color,
92+
)
93+
94+
if self.params.show_label:
95+
# Add label name and confidence
96+
label_text = f"{c.label.name} ({c.confidence:.2f})"
97+
# Create semi-transparent background for text
98+
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
99+
text_bg_padding = 2
100+
draw.rectangle(
101+
[
102+
(
103+
text_bbox[0] - text_bg_padding,
104+
text_bbox[1] - text_bg_padding,
105+
),
106+
(
107+
text_bbox[2] + text_bg_padding,
108+
text_bbox[3] + text_bg_padding,
109+
),
110+
],
111+
fill=(255, 255, 255, 180), # Semi-transparent white
112+
)
113+
# Draw text
114+
draw.text(
115+
(x0, y0),
116+
label_text,
117+
fill=(0, 0, 0, 255), # Solid black
118+
font=font,
119+
)
120+
121+
def _draw_doc_layout(
122+
self, doc: DoclingDocument, images: Optional[dict[Optional[int], Image]] = None
123+
):
124+
"""Draw the document clusters and optionaly the reading order."""
125+
clusters = []
126+
my_images = images or {}
127+
prev_image = None
128+
prev_page_nr = None
129+
for idx, (elem, _) in enumerate(
130+
doc.iterate_items(
131+
included_content_layers={ContentLayer.BODY, ContentLayer.FURNITURE}
132+
)
133+
):
134+
if not isinstance(elem, DocItem):
135+
continue
136+
if len(elem.prov) == 0:
137+
continue # Skip elements without provenances
138+
prov = elem.prov[0]
139+
page_nr = prov.page_no
140+
image = my_images.get(page_nr)
141+
142+
if prev_page_nr is None or page_nr > prev_page_nr: # new page begins
143+
# complete previous drawing
144+
if prev_page_nr is not None and prev_image and clusters:
145+
self._draw_clusters(
146+
image=prev_image,
147+
clusters=clusters,
148+
scale_x=prev_image.width / doc.pages[prev_page_nr].size.width,
149+
scale_y=prev_image.height / doc.pages[prev_page_nr].size.height,
150+
)
151+
clusters = []
152+
153+
if image is None:
154+
page_image = doc.pages[page_nr].image
155+
if page_image is None or (pil_img := page_image.pil_image) is None:
156+
raise RuntimeError("Cannot visualize document without images")
157+
else:
158+
image = deepcopy(pil_img)
159+
my_images[page_nr] = image
160+
tlo_bbox = prov.bbox.to_top_left_origin(
161+
page_height=doc.pages[prov.page_no].size.height
162+
)
163+
cluster = _TLCluster(
164+
id=idx,
165+
label=elem.label,
166+
brec=_TLBoundingRectangle.from_bounding_box(bbox=tlo_bbox),
167+
cells=[],
168+
)
169+
clusters.append(cluster)
170+
171+
prev_page_nr = page_nr
172+
prev_image = image
173+
174+
# complete last drawing
175+
if prev_page_nr is not None and prev_image and clusters:
176+
self._draw_clusters(
177+
image=prev_image,
178+
clusters=clusters,
179+
scale_x=prev_image.width / doc.pages[prev_page_nr].size.width,
180+
scale_y=prev_image.height / doc.pages[prev_page_nr].size.height,
181+
)
182+
183+
return my_images
184+
185+
@override
186+
def get_visualization(
187+
self,
188+
*,
189+
doc: DoclingDocument,
190+
**kwargs,
191+
) -> dict[Optional[int], Image]:
192+
"""Get visualization of the document as images by page."""
193+
base_images = (
194+
self.base_visualizer.get_visualization(doc=doc, **kwargs)
195+
if self.base_visualizer
196+
else None
197+
)
198+
return self._draw_doc_layout(
199+
doc=doc,
200+
images=base_images,
201+
)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Define classes for reading order visualization."""
2+
3+
from copy import deepcopy
4+
from typing import Optional
5+
6+
from PIL import ImageDraw
7+
from PIL.Image import Image
8+
from typing_extensions import override
9+
10+
from docling_core.transforms.visualizer.base import BaseVisualizer
11+
from docling_core.types.doc.document import ContentLayer, DocItem, DoclingDocument
12+
13+
14+
class ReadingOrderVisualizer(BaseVisualizer):
15+
"""Reading order visualizer."""
16+
17+
base_visualizer: Optional[BaseVisualizer] = None
18+
19+
def _draw_arrow(
20+
self,
21+
draw: ImageDraw.ImageDraw,
22+
arrow_coords: tuple[float, float, float, float],
23+
line_width: int = 2,
24+
color: str = "red",
25+
):
26+
"""Draw an arrow inside the given draw object."""
27+
x0, y0, x1, y1 = arrow_coords
28+
29+
# Arrow parameters
30+
start_point = (x0, y0) # Starting point of the arrow
31+
end_point = (x1, y1) # Ending point of the arrow
32+
arrowhead_length = 20 # Length of the arrowhead
33+
arrowhead_width = 10 # Width of the arrowhead
34+
35+
# Draw the arrow shaft (line)
36+
draw.line([start_point, end_point], fill=color, width=line_width)
37+
38+
# Calculate the arrowhead points
39+
dx = end_point[0] - start_point[0]
40+
dy = end_point[1] - start_point[1]
41+
angle = (dx**2 + dy**2) ** 0.5 + 0.01 # Length of the arrow shaft
42+
43+
# Normalized direction vector for the arrow shaft
44+
ux, uy = dx / angle, dy / angle
45+
46+
# Base of the arrowhead
47+
base_x = end_point[0] - ux * arrowhead_length
48+
base_y = end_point[1] - uy * arrowhead_length
49+
50+
# Left and right points of the arrowhead
51+
left_x = base_x - uy * arrowhead_width
52+
left_y = base_y + ux * arrowhead_width
53+
right_x = base_x + uy * arrowhead_width
54+
right_y = base_y - ux * arrowhead_width
55+
56+
# Draw the arrowhead (triangle)
57+
draw.polygon(
58+
[end_point, (left_x, left_y), (right_x, right_y)],
59+
fill=color,
60+
)
61+
return draw
62+
63+
def _draw_doc_reading_order(
64+
self,
65+
doc: DoclingDocument,
66+
images: Optional[dict[Optional[int], Image]] = None,
67+
):
68+
"""Draw the reading order."""
69+
# draw = ImageDraw.Draw(image)
70+
x0, y0 = None, None
71+
my_images: dict[Optional[int], Image] = images or {}
72+
prev_page = None
73+
for elem, _ in doc.iterate_items(
74+
included_content_layers={ContentLayer.BODY, ContentLayer.FURNITURE},
75+
):
76+
if not isinstance(elem, DocItem):
77+
continue
78+
if len(elem.prov) == 0:
79+
continue # Skip elements without provenances
80+
prov = elem.prov[0]
81+
page_no = prov.page_no
82+
image = my_images.get(page_no)
83+
84+
if image is None or prev_page is None or page_no > prev_page:
85+
# new page begins
86+
prev_page = page_no
87+
x0 = y0 = None
88+
89+
if image is None:
90+
page_image = doc.pages[page_no].image
91+
if page_image is None or (pil_img := page_image.pil_image) is None:
92+
raise RuntimeError("Cannot visualize document without images")
93+
else:
94+
image = deepcopy(pil_img)
95+
my_images[page_no] = image
96+
draw = ImageDraw.Draw(image)
97+
98+
# if prov.page_no not in true_doc.pages or prov.page_no != 1:
99+
# logging.error(f"{prov.page_no} not in true_doc.pages -> skipping! ")
100+
# continue
101+
102+
tlo_bbox = prov.bbox.to_top_left_origin(
103+
page_height=doc.pages[prov.page_no].size.height
104+
)
105+
ro_bbox = tlo_bbox.normalized(doc.pages[prov.page_no].size)
106+
ro_bbox.l = round(ro_bbox.l * image.width) # noqa: E741
107+
ro_bbox.r = round(ro_bbox.r * image.width)
108+
ro_bbox.t = round(ro_bbox.t * image.height)
109+
ro_bbox.b = round(ro_bbox.b * image.height)
110+
111+
if ro_bbox.b > ro_bbox.t:
112+
ro_bbox.b, ro_bbox.t = ro_bbox.t, ro_bbox.b
113+
114+
if x0 is None and y0 is None:
115+
x0 = (ro_bbox.l + ro_bbox.r) / 2.0
116+
y0 = (ro_bbox.b + ro_bbox.t) / 2.0
117+
else:
118+
assert x0 is not None
119+
assert y0 is not None
120+
121+
x1 = (ro_bbox.l + ro_bbox.r) / 2.0
122+
y1 = (ro_bbox.b + ro_bbox.t) / 2.0
123+
124+
draw = self._draw_arrow(
125+
draw=draw,
126+
arrow_coords=(x0, y0, x1, y1),
127+
line_width=2,
128+
color="red",
129+
)
130+
x0, y0 = x1, y1
131+
return my_images
132+
133+
@override
134+
def get_visualization(
135+
self,
136+
*,
137+
doc: DoclingDocument,
138+
**kwargs,
139+
) -> dict[Optional[int], Image]:
140+
"""Get visualization of the document as images by page."""
141+
base_images = (
142+
self.base_visualizer.get_visualization(doc=doc, **kwargs)
143+
if self.base_visualizer
144+
else None
145+
)
146+
return self._draw_doc_reading_order(
147+
doc=doc,
148+
images=base_images,
149+
)

0 commit comments

Comments
 (0)