|
| 1 | +"""Key‑value visualizer overlaying key/value cells and their links on page images. |
| 2 | +
|
| 3 | +This module complements :py:class:`layout_visualizer.LayoutVisualizer` by drawing |
| 4 | +*key* and *value* cells plus the directed links between them. It can be stacked |
| 5 | +on top of any other :py:class:`BaseVisualizer` – e.g. first draw the general |
| 6 | +layout, then add the key‑value layer. |
| 7 | +""" |
| 8 | + |
| 9 | +from copy import deepcopy |
| 10 | +from typing import Optional, Union |
| 11 | + |
| 12 | +from PIL import ImageDraw, ImageFont |
| 13 | +from PIL.Image import Image |
| 14 | +from PIL.ImageFont import FreeTypeFont |
| 15 | +from pydantic import BaseModel |
| 16 | +from typing_extensions import override |
| 17 | + |
| 18 | +from docling_core.transforms.visualizer.base import BaseVisualizer |
| 19 | +from docling_core.types.doc.document import ContentLayer, DoclingDocument |
| 20 | +from docling_core.types.doc.labels import GraphCellLabel, GraphLinkLabel |
| 21 | + |
| 22 | +# --------------------------------------------------------------------------- |
| 23 | +# Helper functions / constants |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | + |
| 26 | +# Semi‑transparent RGBA colours for key / value cells and their connecting link |
| 27 | +_KEY_FILL = (0, 170, 0, 70) # greenish |
| 28 | +_VALUE_FILL = (0, 0, 200, 70) # bluish |
| 29 | +_LINK_COLOUR = (255, 0, 0, 255) # red line (solid) |
| 30 | + |
| 31 | +_LABEL_TXT_COLOUR = (0, 0, 0, 255) |
| 32 | +_LABEL_BG_COLOUR = (255, 255, 255, 180) # semi‑transparent white |
| 33 | + |
| 34 | + |
class KeyValueVisualizer(BaseVisualizer):
    """Draw key/value graphs stored in :py:attr:`DoclingDocument.key_value_items`.

    Each graph cell is rendered as a filled rectangle and each ``TO_VALUE``
    link as an arrow from the source cell to the target cell. When
    ``base_visualizer`` is set, the images it returns are used as the canvas;
    pages it does not cover fall back to a copy of the page's own image.
    """

    class Params(BaseModel):
        """Parameters for KeyValueVisualizer controlling label and cell id display, and content layers to visualize."""

        show_label: bool = True  # draw cell text close to bbox
        show_cell_id: bool = False  # annotate each rectangle with its cell_id
        content_layers: set[ContentLayer] = set(ContentLayer)

    base_visualizer: Optional[BaseVisualizer] = None
    params: Params = Params()

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------

    def _cell_fill(self, label: GraphCellLabel) -> tuple[int, int, int, int]:
        """Return RGBA fill colour depending on *label*.

        ``GraphCellLabel.KEY`` maps to the key colour; every other label maps
        to the value colour.
        """
        return _KEY_FILL if label == GraphCellLabel.KEY else _VALUE_FILL

    @staticmethod
    def _scaled_bbox(
        bbox,
        *,
        page_height: float,
        scale_x: float,
        scale_y: float,
    ) -> tuple[float, float, float, float]:
        """Return *bbox* as ``(x0, y0, x1, y1)`` in image coordinates.

        The box is first converted with ``to_top_left_origin`` (using
        *page_height*) and then multiplied by the per-axis scale factors.
        """
        left, top, right, bottom = bbox.to_top_left_origin(
            page_height=page_height
        ).as_tuple()
        return left * scale_x, top * scale_y, right * scale_x, bottom * scale_y

    @staticmethod
    def _draw_link_arrow(
        draw: ImageDraw.ImageDraw,
        src_xy: tuple[float, float],
        tgt_xy: tuple[float, float],
    ) -> None:
        """Draw a line from *src_xy* to *tgt_xy* with an arrow head at the target."""
        draw.line([src_xy, tgt_xy], fill=_LINK_COLOUR, width=2)

        # Exact arrow-head geometry is not critical for visual inspection.
        arrow_len = 6
        dx = tgt_xy[0] - src_xy[0]
        dy = tgt_xy[1] - src_xy[1]
        # ``or 1.0`` avoids a division by zero when both centres coincide.
        length = (dx**2 + dy**2) ** 0.5 or 1.0
        ux, uy = dx / length, dy / length
        # Vector perpendicular to the link direction.
        px, py = -uy, ux
        # The two points forming the base of the arrow-head triangle.
        head_base_left = (
            tgt_xy[0] - ux * arrow_len - px * arrow_len / 2,
            tgt_xy[1] - uy * arrow_len - py * arrow_len / 2,
        )
        head_base_right = (
            tgt_xy[0] - ux * arrow_len + px * arrow_len / 2,
            tgt_xy[1] - uy * arrow_len + py * arrow_len / 2,
        )
        draw.polygon([tgt_xy, head_base_left, head_base_right], fill=_LINK_COLOUR)

    def _draw_key_value_layer(
        self,
        *,
        image: Image,
        doc: DoclingDocument,
        page_no: int,
        scale_x: float,
        scale_y: float,
    ) -> None:
        """Draw every key-value graph that has cells on *page_no* onto *image*.

        Args:
            image: Image to draw on; it is modified in place.
            doc: Document providing ``key_value_items`` and the page sizes.
            page_no: Page whose cells and links are drawn. Cells without
                provenance or on another page are skipped, and a link is only
                drawn when both of its cells are on this page.
            scale_x: Horizontal factor applied to page coordinates.
            scale_y: Vertical factor applied to page coordinates.
        """
        draw = ImageDraw.Draw(image, "RGBA")
        # Choose a small truetype font if available, otherwise default bitmap font
        font: Union[ImageFont.ImageFont, FreeTypeFont]
        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except OSError:
            font = ImageFont.load_default()

        # The page height is constant for the whole call, so look it up once.
        page_height = doc.pages[page_no].size.height

        for kv_item in doc.key_value_items:
            cell_dict = {cell.cell_id: cell for cell in kv_item.graph.cells}

            # --------------------------------------------------------------
            # First draw cells (rectangles + optional labels)
            # --------------------------------------------------------------
            for cell in cell_dict.values():
                if cell.prov is None or cell.prov.page_no != page_no:
                    continue  # skip cells not on this page or without bbox

                x0, y0, x1, y1 = self._scaled_bbox(
                    cell.prov.bbox,
                    page_height=page_height,
                    scale_x=scale_x,
                    scale_y=scale_y,
                )
                fill_rgba = self._cell_fill(cell.label)

                draw.rectangle(
                    [(x0, y0), (x1, y1)],
                    outline=fill_rgba[:-1] + (255,),  # same hue, fully opaque
                    fill=fill_rgba,
                )

                # ``show_cell_id`` and ``show_label`` are independent switches:
                # either one alone is enough to produce an annotation.
                txt_parts = []
                if self.params.show_cell_id:
                    txt_parts.append(str(cell.cell_id))
                if self.params.show_label:
                    txt_parts.append(cell.text)
                if not txt_parts:
                    continue

                label_text = " | ".join(txt_parts)
                tbx = draw.textbbox((x0, y0), label_text, font=font)
                pad = 2
                # Backdrop behind the text so it stays readable on the page.
                draw.rectangle(
                    [(tbx[0] - pad, tbx[1] - pad), (tbx[2] + pad, tbx[3] + pad)],
                    fill=_LABEL_BG_COLOUR,
                )
                draw.text((x0, y0), label_text, font=font, fill=_LABEL_TXT_COLOUR)

            # --------------------------------------------------------------
            # Then draw links (after rectangles so they appear on top)
            # --------------------------------------------------------------
            for link in kv_item.graph.links:
                if link.label != GraphLinkLabel.TO_VALUE:
                    # Future-proof: ignore other link types silently
                    continue

                src_cell = cell_dict.get(link.source_cell_id)
                tgt_cell = cell_dict.get(link.target_cell_id)
                if src_cell is None or tgt_cell is None:
                    continue
                if (
                    src_cell.prov is None
                    or tgt_cell.prov is None
                    or src_cell.prov.page_no != page_no
                    or tgt_cell.prov.page_no != page_no
                ):
                    continue  # only draw if both ends are on this page

                sx0, sy0, sx1, sy1 = self._scaled_bbox(
                    src_cell.prov.bbox,
                    page_height=page_height,
                    scale_x=scale_x,
                    scale_y=scale_y,
                )
                tx0, ty0, tx1, ty1 = self._scaled_bbox(
                    tgt_cell.prov.bbox,
                    page_height=page_height,
                    scale_x=scale_x,
                    scale_y=scale_y,
                )
                # Links run between the centres of the two cell boxes.
                src_xy = ((sx0 + sx1) / 2, (sy0 + sy1) / 2)
                tgt_xy = ((tx0 + tx1) / 2, (ty0 + ty1) / 2)
                self._draw_link_arrow(draw, src_xy, tgt_xy)

    # ---------------------------------------------------------------------
    # Public API – BaseVisualizer implementation
    # ---------------------------------------------------------------------

    @override
    def get_visualization(
        self,
        *,
        doc: DoclingDocument,
        included_content_layers: Optional[set[ContentLayer]] = None,
        **kwargs,
    ) -> dict[Optional[int], Image]:
        """Return page-wise images with key/value overlay (incl. base layer).

        Args:
            doc: Document whose pages and key-value items are visualized.
            included_content_layers: Forwarded unchanged to ``base_visualizer``
                when one is configured.
            **kwargs: Forwarded to ``base_visualizer.get_visualization``.

        Returns:
            A mapping from page number to the image carrying the overlay.

        Raises:
            RuntimeError: If a page has neither a base-visualizer image nor a
                page image of its own.
        """
        base_images = (
            self.base_visualizer.get_visualization(
                doc=doc, included_content_layers=included_content_layers, **kwargs
            )
            if self.base_visualizer
            else None
        )

        # NOTE(review): the defaulted value is not consulted below, so key-value
        # items are currently drawn regardless of content layer - confirm
        # whether filtering by layer is intended.
        if included_content_layers is None:
            included_content_layers = set(ContentLayer)

        images: dict[Optional[int], Image] = {}

        # Ensure we have page images to draw on. All pages are resolved before
        # any drawing starts, so a missing image fails before any overlay work.
        for page_nr, page in doc.pages.items():
            base_img = (base_images or {}).get(page_nr)
            if base_img is None:
                if page.image is None or (pil_img := page.image.pil_image) is None:
                    raise RuntimeError("Cannot visualize document without page images")
                # Copy so the document's own page image is not drawn on.
                base_img = deepcopy(pil_img)
            images[page_nr] = base_img

        # Overlay key-value content
        for page_nr, img in images.items():  # type: ignore
            assert isinstance(page_nr, int)
            # The canvas may differ in resolution from the page's nominal size.
            scale_x = img.width / doc.pages[page_nr].size.width
            scale_y = img.height / doc.pages[page_nr].size.height
            self._draw_key_value_layer(
                image=img,
                doc=doc,
                page_no=page_nr,
                scale_x=scale_x,
                scale_y=scale_y,
            )

        return images
0 commit comments