Skip to content

Commit 0f7ca77

Browse files
cau-gitCopilot
andauthored
feat: Key-value visualizer (#360)
* Add key-value visualizer, and chooser arg on DoclingDocument.get_visualization Signed-off-by: Christoph Auer <[email protected]> * Update docling_core/transforms/visualizer/key_value_visualizer.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Christoph Auer <[email protected]> * Update docling_core/transforms/visualizer/key_value_visualizer.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Christoph Auer <[email protected]> * Address review comments Signed-off-by: Christoph Auer <[email protected]> --------- Signed-off-by: Christoph Auer <[email protected]> Signed-off-by: Christoph Auer <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 93dd0a9 commit 0f7ca77

File tree

2 files changed

+262
-10
lines changed

2 files changed

+262
-10
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Key‑value visualizer overlaying key/value cells and their links on page images.
2+
3+
This module complements :py:class:`layout_visualizer.LayoutVisualizer` by drawing
4+
*key* and *value* cells plus the directed links between them. It can be stacked
5+
on top of any other :py:class:`BaseVisualizer` – e.g. first draw the general
6+
layout, then add the key‑value layer.
7+
"""
8+
9+
from copy import deepcopy
10+
from typing import Optional, Union
11+
12+
from PIL import ImageDraw, ImageFont
13+
from PIL.Image import Image
14+
from PIL.ImageFont import FreeTypeFont
15+
from pydantic import BaseModel
16+
from typing_extensions import override
17+
18+
from docling_core.transforms.visualizer.base import BaseVisualizer
19+
from docling_core.types.doc.document import ContentLayer, DoclingDocument
20+
from docling_core.types.doc.labels import GraphCellLabel, GraphLinkLabel
21+
22+
# ---------------------------------------------------------------------------
23+
# Helper functions / constants
24+
# ---------------------------------------------------------------------------
25+
26+
# RGBA colour tuples (red, green, blue, alpha) used by the key-value overlay.
# The two cell fills use a low alpha (70) so they are semi-transparent, while
# the link colour uses alpha 255 and is therefore solid.
27+
_KEY_FILL = (0, 170, 0, 70) # greenish fill for KEY cells
28+
_VALUE_FILL = (0, 0, 200, 70) # bluish fill for every non-KEY cell
29+
_LINK_COLOUR = (255, 0, 0, 255) # red line (solid)
30+
31+
_LABEL_TXT_COLOUR = (0, 0, 0, 255) # solid black label text
32+
_LABEL_BG_COLOUR = (255, 255, 255, 180) # semi-transparent white box behind the label text
33+
34+
35+
class KeyValueVisualizer(BaseVisualizer):
    """Draw key/value graphs stored in :py:attr:`DoclingDocument.key_value_items`.

    Every graph cell that carries provenance for a page is rendered as a
    semi-transparent rectangle (``_KEY_FILL`` for keys, ``_VALUE_FILL`` for
    everything else). ``TO_VALUE`` links whose two end cells are on the same
    page are rendered as an arrow from the source cell centre to the target
    cell centre. When ``base_visualizer`` is set, its images are used as the
    canvas; otherwise copies of the page images stored in the document are used.
    """

    class Params(BaseModel):
        """Parameters for KeyValueVisualizer controlling label and cell id display, and content layers to visualize."""

        show_label: bool = True  # draw cell text close to bbox
        show_cell_id: bool = False  # annotate each rectangle with its cell_id
        # Layers drawn when the caller passes no ``included_content_layers``.
        content_layers: set[ContentLayer] = set(ContentLayer)

    base_visualizer: Optional[BaseVisualizer] = None
    params: Params = Params()

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------

    def _cell_fill(self, label: GraphCellLabel) -> tuple[int, int, int, int]:
        """Return RGBA fill colour depending on *label*."""
        return _KEY_FILL if label == GraphCellLabel.KEY else _VALUE_FILL

    @staticmethod
    def _scaled_bbox(
        bbox,
        *,
        page_height: float,
        scale_x: float,
        scale_y: float,
    ) -> tuple[float, float, float, float]:
        """Return *bbox* as ``(x0, y0, x1, y1)`` in image pixel coordinates.

        :param bbox: Bounding box of a cell provenance; must offer
            ``to_top_left_origin(page_height=...)`` and ``as_tuple()``.
        :param page_height: Height of the page the box lives on (document units).
        :param scale_x: Horizontal factor from document units to image pixels.
        :param scale_y: Vertical factor from document units to image pixels.
        :returns: Left, top, right and bottom edge in a top-left-origin frame.
        """
        left, top, right, bottom = bbox.to_top_left_origin(
            page_height=page_height
        ).as_tuple()
        return left * scale_x, top * scale_y, right * scale_x, bottom * scale_y

    @staticmethod
    def _draw_arrow(
        draw: ImageDraw.ImageDraw,
        src_xy: tuple[float, float],
        tgt_xy: tuple[float, float],
    ) -> None:
        """Draw a line from *src_xy* to *tgt_xy* with a filled triangular head at *tgt_xy*."""
        draw.line([src_xy, tgt_xy], fill=_LINK_COLOUR, width=2)

        # Arrow head: a small triangle whose tip is the target point. Exact
        # geometry is not critical for visual inspection.
        arrow_len = 6
        dx = tgt_xy[0] - src_xy[0]
        dy = tgt_xy[1] - src_xy[1]
        # ``or 1.0`` avoids a division by zero when both points coincide.
        length = (dx**2 + dy**2) ** 0.5 or 1.0
        ux, uy = dx / length, dy / length
        # Perpendicular to the link direction; spans the base of the triangle.
        px, py = -uy, ux
        head_base_left = (
            tgt_xy[0] - ux * arrow_len - px * arrow_len / 2,
            tgt_xy[1] - uy * arrow_len - py * arrow_len / 2,
        )
        head_base_right = (
            tgt_xy[0] - ux * arrow_len + px * arrow_len / 2,
            tgt_xy[1] - uy * arrow_len + py * arrow_len / 2,
        )
        draw.polygon([tgt_xy, head_base_left, head_base_right], fill=_LINK_COLOUR)

    def _draw_key_value_layer(
        self,
        *,
        image: Image,
        doc: DoclingDocument,
        page_no: int,
        scale_x: float,
        scale_y: float,
        included_content_layers: Optional[set[ContentLayer]] = None,
    ) -> None:
        """Draw every key-value graph that has cells on *page_no* onto *image*.

        :param image: Canvas for the page; it is modified in place.
        :param doc: Document whose ``key_value_items`` are visualized.
        :param page_no: Page to draw; must be a key of ``doc.pages``.
        :param scale_x: Horizontal factor from page units to image pixels.
        :param scale_y: Vertical factor from page units to image pixels.
        :param included_content_layers: When given, key-value items whose
            ``content_layer`` is not in this set are skipped. ``None`` (the
            default) draws all items.
        """
        draw = ImageDraw.Draw(image, "RGBA")
        # Choose a small truetype font if available, otherwise default bitmap font
        font: Union[ImageFont.ImageFont, FreeTypeFont]
        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except OSError:
            font = ImageFont.load_default()

        # Loop invariants, looked up once instead of per cell / per link.
        page_height = doc.pages[page_no].size.height
        show_label = self.params.show_label
        show_cell_id = self.params.show_cell_id
        label_pad = 2

        for kv_item in doc.key_value_items:
            if (
                included_content_layers is not None
                and kv_item.content_layer not in included_content_layers
            ):
                continue  # item belongs to a layer the caller excluded

            cell_dict = {cell.cell_id: cell for cell in kv_item.graph.cells}

            # --------------------------------------------------------------
            # First draw cells (rectangles + optional labels)
            # --------------------------------------------------------------
            for cell in cell_dict.values():
                if cell.prov is None or cell.prov.page_no != page_no:
                    continue  # skip cells not on this page or without bbox

                x0, y0, x1, y1 = self._scaled_bbox(
                    cell.prov.bbox,
                    page_height=page_height,
                    scale_x=scale_x,
                    scale_y=scale_y,
                )
                fill_rgba = self._cell_fill(cell.label)
                draw.rectangle(
                    [(x0, y0), (x1, y1)],
                    # Same hue as the fill, but fully opaque for a crisp border.
                    outline=fill_rgba[:-1] + (255,),
                    fill=fill_rgba,
                )

                # The two flags are independent: the id must also be shown
                # when the text label is switched off.
                txt_parts = []
                if show_cell_id:
                    txt_parts.append(str(cell.cell_id))
                if show_label:
                    txt_parts.append(cell.text)
                if not txt_parts:
                    continue

                label_text = " | ".join(txt_parts)
                tbx = draw.textbbox((x0, y0), label_text, font=font)
                draw.rectangle(
                    [
                        (tbx[0] - label_pad, tbx[1] - label_pad),
                        (tbx[2] + label_pad, tbx[3] + label_pad),
                    ],
                    fill=_LABEL_BG_COLOUR,
                )
                draw.text((x0, y0), label_text, font=font, fill=_LABEL_TXT_COLOUR)

            # --------------------------------------------------------------
            # Then draw links (after rectangles so they appear on top)
            # --------------------------------------------------------------
            for link in kv_item.graph.links:
                if link.label != GraphLinkLabel.TO_VALUE:
                    # Future-proof: ignore other link types silently
                    continue

                src_cell = cell_dict.get(link.source_cell_id)
                tgt_cell = cell_dict.get(link.target_cell_id)
                if src_cell is None or tgt_cell is None:
                    continue  # dangling link: an end refers to an unknown cell
                if (
                    src_cell.prov is None
                    or tgt_cell.prov is None
                    or src_cell.prov.page_no != page_no
                    or tgt_cell.prov.page_no != page_no
                ):
                    continue  # only draw if both ends are on this page

                end_points = []
                for end_cell in (src_cell, tgt_cell):
                    ex0, ey0, ex1, ey1 = self._scaled_bbox(
                        end_cell.prov.bbox,
                        page_height=page_height,
                        scale_x=scale_x,
                        scale_y=scale_y,
                    )
                    end_points.append(((ex0 + ex1) / 2, (ey0 + ey1) / 2))

                self._draw_arrow(draw, end_points[0], end_points[1])

    # ---------------------------------------------------------------------
    # Public API – BaseVisualizer implementation
    # ---------------------------------------------------------------------

    @override
    def get_visualization(
        self,
        *,
        doc: DoclingDocument,
        included_content_layers: Optional[set[ContentLayer]] = None,
        **kwargs,
    ) -> dict[Optional[int], Image]:
        """Return page-wise images with key/value overlay (incl. base layer).

        :param doc: Document to visualize.
        :param included_content_layers: Content layers to draw. It is passed
            unchanged to ``base_visualizer``; for the key-value overlay,
            ``None`` falls back to ``params.content_layers``.
        :param kwargs: Forwarded to ``base_visualizer.get_visualization``.
        :returns: Mapping of page number to the rendered image.
        :raises RuntimeError: If a page has neither a base-visualizer image nor
            a page image of its own.
        """
        base_images = (
            self.base_visualizer.get_visualization(
                doc=doc, included_content_layers=included_content_layers, **kwargs
            )
            if self.base_visualizer
            else None
        )

        if included_content_layers is None:
            included_content_layers = self.params.content_layers

        images: dict[Optional[int], Image] = {}

        for page_nr, page in doc.pages.items():
            # Prefer the canvas produced by the base visualizer, if any.
            base_img = (base_images or {}).get(page_nr)
            if base_img is None:
                if page.image is None or (pil_img := page.image.pil_image) is None:
                    raise RuntimeError("Cannot visualize document without page images")
                # Copy so the image stored in the document is never drawn on.
                base_img = deepcopy(pil_img)

            # Overlay key-value content
            self._draw_key_value_layer(
                image=base_img,
                doc=doc,
                page_no=page_nr,
                scale_x=base_img.width / page.size.width,
                scale_y=base_img.height / page.size.height,
                included_content_layers=included_content_layers,
            )
            images[page_nr] = base_img

        return images

docling_core/types/doc/document.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5543,27 +5543,62 @@ def get_visualization(
55435543
self,
55445544
show_label: bool = True,
55455545
show_branch_numbering: bool = False,
5546+
viz_mode: Literal["reading_order", "key_value"] = "reading_order",
5547+
show_cell_id: bool = False,
55465548
) -> dict[Optional[int], PILImage.Image]:
5547-
"""Get visualization of the document as images by page."""
5549+
"""Get visualization of the document as images by page.
5550+
5551+
:param show_label: Show labels on elements (applies to all visualizers).
5552+
:type show_label: bool
5553+
:param show_branch_numbering: Show branch numbering (reading order visualizer only).
5554+
:type show_branch_numbering: bool
5555+
:param viz_mode: Which visualizer to use. One of 'reading_order' (default), 'key_value'.
5556+
:type viz_mode: Literal["reading_order", "key_value"]
5557+
:param show_cell_id: Show cell IDs (key value visualizer only).
5558+
:type show_cell_id: bool
5559+
5560+
:returns: Dictionary mapping page numbers to PIL images.
5561+
:rtype: dict[Optional[int], PILImage.Image]
5562+
"""
5563+
from docling_core.transforms.visualizer.base import BaseVisualizer
5564+
from docling_core.transforms.visualizer.key_value_visualizer import (
5565+
KeyValueVisualizer,
5566+
)
55485567
from docling_core.transforms.visualizer.layout_visualizer import (
55495568
LayoutVisualizer,
55505569
)
55515570
from docling_core.transforms.visualizer.reading_order_visualizer import (
55525571
ReadingOrderVisualizer,
55535572
)
55545573

5555-
visualizer = ReadingOrderVisualizer(
5556-
base_visualizer=LayoutVisualizer(
5557-
params=LayoutVisualizer.Params(
5574+
visualizer_obj: BaseVisualizer
5575+
if viz_mode == "reading_order":
5576+
visualizer_obj = ReadingOrderVisualizer(
5577+
base_visualizer=LayoutVisualizer(
5578+
params=LayoutVisualizer.Params(
5579+
show_label=show_label,
5580+
),
5581+
),
5582+
params=ReadingOrderVisualizer.Params(
5583+
show_branch_numbering=show_branch_numbering,
5584+
),
5585+
)
5586+
elif viz_mode == "key_value":
5587+
visualizer_obj = KeyValueVisualizer(
5588+
base_visualizer=LayoutVisualizer(
5589+
params=LayoutVisualizer.Params(
5590+
show_label=show_label,
5591+
),
5592+
),
5593+
params=KeyValueVisualizer.Params(
55585594
show_label=show_label,
5595+
show_cell_id=show_cell_id,
55595596
),
5560-
),
5561-
params=ReadingOrderVisualizer.Params(
5562-
show_branch_numbering=show_branch_numbering,
5563-
),
5564-
)
5565-
images = visualizer.get_visualization(doc=self)
5597+
)
5598+
else:
5599+
raise ValueError(f"Unknown visualization mode: {viz_mode}")
55665600

5601+
images = visualizer_obj.get_visualization(doc=self)
55675602
return images
55685603

55695604
@field_validator("version")

0 commit comments

Comments
 (0)