|
| 1 | +"""Define classes for layout visualization.""" |
| 2 | + |
| 3 | +from copy import deepcopy |
| 4 | +from typing import Literal, Optional, Union |
| 5 | + |
| 6 | +from PIL import ImageDraw, ImageFont |
| 7 | +from PIL.Image import Image |
| 8 | +from PIL.ImageFont import FreeTypeFont |
| 9 | +from pydantic import BaseModel |
| 10 | +from typing_extensions import override |
| 11 | + |
| 12 | +from docling_core.transforms.visualizer.base import BaseVisualizer |
| 13 | +from docling_core.types.doc import DocItemLabel |
| 14 | +from docling_core.types.doc.base import CoordOrigin |
| 15 | +from docling_core.types.doc.document import ContentLayer, DocItem, DoclingDocument |
| 16 | +from docling_core.types.doc.page import BoundingRectangle, TextCell |
| 17 | + |
| 18 | + |
| 19 | +class _TLBoundingRectangle(BoundingRectangle): |
| 20 | + coord_origin: Literal[CoordOrigin.TOPLEFT] = CoordOrigin.TOPLEFT |
| 21 | + |
| 22 | + |
| 23 | +class _TLTextCell(TextCell): |
| 24 | + rect: _TLBoundingRectangle |
| 25 | + |
| 26 | + |
| 27 | +class _TLCluster(BaseModel): |
| 28 | + id: int |
| 29 | + label: DocItemLabel |
| 30 | + brec: _TLBoundingRectangle |
| 31 | + confidence: float = 1.0 |
| 32 | + cells: list[_TLTextCell] = [] |
| 33 | + children: list["_TLCluster"] = [] # Add child cluster support |
| 34 | + |
| 35 | + |
| 36 | +class LayoutVisualizer(BaseVisualizer): |
| 37 | + """Layout visualizer.""" |
| 38 | + |
| 39 | + class Params(BaseModel): |
| 40 | + """Layout visualization parameters.""" |
| 41 | + |
| 42 | + show_label: bool = True |
| 43 | + |
| 44 | + base_visualizer: Optional[BaseVisualizer] = None |
| 45 | + params: Params = Params() |
| 46 | + |
| 47 | + def _draw_clusters( |
| 48 | + self, image: Image, clusters: list[_TLCluster], scale_x: float, scale_y: float |
| 49 | + ) -> None: |
| 50 | + """Draw clusters on an image.""" |
| 51 | + draw = ImageDraw.Draw(image, "RGBA") |
| 52 | + # Create a smaller font for the labels |
| 53 | + font: Union[ImageFont.ImageFont, FreeTypeFont] |
| 54 | + try: |
| 55 | + font = ImageFont.truetype("arial.ttf", 12) |
| 56 | + except OSError: |
| 57 | + # Fallback to default font if arial is not available |
| 58 | + font = ImageFont.load_default() |
| 59 | + for c_tl in clusters: |
| 60 | + all_clusters = [c_tl, *c_tl.children] |
| 61 | + for c in all_clusters: |
| 62 | + # Draw cells first (underneath) |
| 63 | + cell_color = (0, 0, 0, 40) # Transparent black for cells |
| 64 | + for tc in c.cells: |
| 65 | + cx0, cy0, cx1, cy1 = tc.rect.to_bounding_box().as_tuple() |
| 66 | + cx0 *= scale_x |
| 67 | + cx1 *= scale_x |
| 68 | + cy0 *= scale_y |
| 69 | + cy1 *= scale_y |
| 70 | + |
| 71 | + draw.rectangle( |
| 72 | + [(cx0, cy0), (cx1, cy1)], |
| 73 | + outline=None, |
| 74 | + fill=cell_color, |
| 75 | + ) |
| 76 | + # Draw cluster rectangle |
| 77 | + x0, y0, x1, y1 = c.brec.to_bounding_box().as_tuple() |
| 78 | + x0 *= scale_x |
| 79 | + x1 *= scale_x |
| 80 | + y0 *= scale_y |
| 81 | + y1 *= scale_y |
| 82 | + |
| 83 | + cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70) |
| 84 | + cluster_outline_color = ( |
| 85 | + *list(DocItemLabel.get_color(c.label)), |
| 86 | + 255, |
| 87 | + ) |
| 88 | + draw.rectangle( |
| 89 | + [(x0, y0), (x1, y1)], |
| 90 | + outline=cluster_outline_color, |
| 91 | + fill=cluster_fill_color, |
| 92 | + ) |
| 93 | + |
| 94 | + if self.params.show_label: |
| 95 | + # Add label name and confidence |
| 96 | + label_text = f"{c.label.name} ({c.confidence:.2f})" |
| 97 | + # Create semi-transparent background for text |
| 98 | + text_bbox = draw.textbbox((x0, y0), label_text, font=font) |
| 99 | + text_bg_padding = 2 |
| 100 | + draw.rectangle( |
| 101 | + [ |
| 102 | + ( |
| 103 | + text_bbox[0] - text_bg_padding, |
| 104 | + text_bbox[1] - text_bg_padding, |
| 105 | + ), |
| 106 | + ( |
| 107 | + text_bbox[2] + text_bg_padding, |
| 108 | + text_bbox[3] + text_bg_padding, |
| 109 | + ), |
| 110 | + ], |
| 111 | + fill=(255, 255, 255, 180), # Semi-transparent white |
| 112 | + ) |
| 113 | + # Draw text |
| 114 | + draw.text( |
| 115 | + (x0, y0), |
| 116 | + label_text, |
| 117 | + fill=(0, 0, 0, 255), # Solid black |
| 118 | + font=font, |
| 119 | + ) |
| 120 | + |
| 121 | + def _draw_doc_layout( |
| 122 | + self, doc: DoclingDocument, images: Optional[dict[Optional[int], Image]] = None |
| 123 | + ): |
| 124 | + """Draw the document clusters and optionaly the reading order.""" |
| 125 | + clusters = [] |
| 126 | + my_images = images or {} |
| 127 | + prev_image = None |
| 128 | + prev_page_nr = None |
| 129 | + for idx, (elem, _) in enumerate( |
| 130 | + doc.iterate_items( |
| 131 | + included_content_layers={ContentLayer.BODY, ContentLayer.FURNITURE} |
| 132 | + ) |
| 133 | + ): |
| 134 | + if not isinstance(elem, DocItem): |
| 135 | + continue |
| 136 | + if len(elem.prov) == 0: |
| 137 | + continue # Skip elements without provenances |
| 138 | + prov = elem.prov[0] |
| 139 | + page_nr = prov.page_no |
| 140 | + image = my_images.get(page_nr) |
| 141 | + |
| 142 | + if prev_page_nr is None or page_nr > prev_page_nr: # new page begins |
| 143 | + # complete previous drawing |
| 144 | + if prev_page_nr is not None and prev_image and clusters: |
| 145 | + self._draw_clusters( |
| 146 | + image=prev_image, |
| 147 | + clusters=clusters, |
| 148 | + scale_x=prev_image.width / doc.pages[prev_page_nr].size.width, |
| 149 | + scale_y=prev_image.height / doc.pages[prev_page_nr].size.height, |
| 150 | + ) |
| 151 | + clusters = [] |
| 152 | + |
| 153 | + if image is None: |
| 154 | + page_image = doc.pages[page_nr].image |
| 155 | + if page_image is None or (pil_img := page_image.pil_image) is None: |
| 156 | + raise RuntimeError("Cannot visualize document without images") |
| 157 | + else: |
| 158 | + image = deepcopy(pil_img) |
| 159 | + my_images[page_nr] = image |
| 160 | + tlo_bbox = prov.bbox.to_top_left_origin( |
| 161 | + page_height=doc.pages[prov.page_no].size.height |
| 162 | + ) |
| 163 | + cluster = _TLCluster( |
| 164 | + id=idx, |
| 165 | + label=elem.label, |
| 166 | + brec=_TLBoundingRectangle.from_bounding_box(bbox=tlo_bbox), |
| 167 | + cells=[], |
| 168 | + ) |
| 169 | + clusters.append(cluster) |
| 170 | + |
| 171 | + prev_page_nr = page_nr |
| 172 | + prev_image = image |
| 173 | + |
| 174 | + # complete last drawing |
| 175 | + if prev_page_nr is not None and prev_image and clusters: |
| 176 | + self._draw_clusters( |
| 177 | + image=prev_image, |
| 178 | + clusters=clusters, |
| 179 | + scale_x=prev_image.width / doc.pages[prev_page_nr].size.width, |
| 180 | + scale_y=prev_image.height / doc.pages[prev_page_nr].size.height, |
| 181 | + ) |
| 182 | + |
| 183 | + return my_images |
| 184 | + |
| 185 | + @override |
| 186 | + def get_visualization( |
| 187 | + self, |
| 188 | + *, |
| 189 | + doc: DoclingDocument, |
| 190 | + **kwargs, |
| 191 | + ) -> dict[Optional[int], Image]: |
| 192 | + """Get visualization of the document as images by page.""" |
| 193 | + base_images = ( |
| 194 | + self.base_visualizer.get_visualization(doc=doc, **kwargs) |
| 195 | + if self.base_visualizer |
| 196 | + else None |
| 197 | + ) |
| 198 | + return self._draw_doc_layout( |
| 199 | + doc=doc, |
| 200 | + images=base_images, |
| 201 | + ) |
0 commit comments