|
3 | 3 | import warnings |
4 | 4 | from collections.abc import Iterable |
5 | 5 | from pathlib import Path |
6 | | -from typing import Optional |
| 6 | +from typing import List, Optional, Union |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | from docling_core.types.doc import DocItemLabel |
@@ -148,72 +148,90 @@ def draw_clusters_and_cells_side_by_side( |
def __call__(
    self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
    """Run batched layout prediction over ``page_batch`` and yield every page.

    Pages are yielded in their input order. A page whose backend reports
    itself as invalid is yielded untouched. For every other page the raw
    predictions are turned into clusters, post-processed, scored, and stored
    on ``page.predictions.layout`` before the page is yielded.

    Args:
        conv_res: Conversion result that receives the "layout" timing and the
            per-page ``layout_score`` / ``ocr_score`` confidence values.
        page_batch: Pages to process. The iterable is fully consumed before
            the first page is yielded, because all valid pages are sent to
            the predictor in a single call.

    Yields:
        Each page of ``page_batch``, in input order.

    Raises:
        RuntimeError: If the predictor does not return exactly one
            prediction list per valid page.
    """
    # Materialize the batch: it is walked once to collect the images and a
    # second time to emit the pages in their original order.
    pages = list(page_batch)

    # Remember each page's validity so the backend is queried only once and
    # both passes agree on which pages were sent to the predictor.
    page_is_valid: List[bool] = []
    valid_page_images: List[Union[Image.Image, np.ndarray]] = []

    for page in pages:
        assert page._backend is not None
        is_valid = page._backend.is_valid()
        page_is_valid.append(is_valid)
        if not is_valid:
            continue

        assert page.size is not None
        page_image = page.get_image(scale=1.0)
        assert page_image is not None
        valid_page_images.append(page_image)

    # Single batched model call; only this call is timed under "layout".
    batch_predictions = []
    if valid_page_images:
        with TimeRecorder(conv_res, "layout"):
            batch_predictions = self.layout_predictor.predict_batch(  # type: ignore[attr-defined]
                valid_page_images
            )

    # Fail early with a clear message instead of an IndexError (too few
    # predictions) or silently dropped results (too many) further down.
    if len(batch_predictions) != len(valid_page_images):
        raise RuntimeError(
            f"Layout predictor returned {len(batch_predictions)} prediction(s) "
            f"for {len(valid_page_images)} page image(s)"
        )

    # Predictions come back in the order the images were submitted, i.e. the
    # order of the valid pages, so they can be consumed sequentially.
    remaining_predictions = iter(batch_predictions)
    for page, is_valid in zip(pages, page_is_valid):
        if not is_valid:
            yield page
            continue

        page_predictions = next(remaining_predictions)

        clusters = []
        for ix, pred_item in enumerate(page_predictions):
            label = DocItemLabel(
                pred_item["label"].lower().replace(" ", "_").replace("-", "_")
            )  # Temporary, until docling-ibm-model uses docling-core types
            clusters.append(
                Cluster(
                    id=ix,
                    label=label,
                    confidence=pred_item["confidence"],
                    bbox=BoundingBox.model_validate(pred_item),
                    cells=[],
                )
            )

        if settings.debug.visualize_raw_layout:
            self.draw_clusters_and_cells_side_by_side(
                conv_res, page, clusters, mode_prefix="raw"
            )

        # Apply postprocessing
        processed_clusters, processed_cells = LayoutPostprocessor(
            page, clusters, self.options
        ).postprocess()
        # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally

        # Averaging an empty list makes numpy emit a RuntimeWarning; those
        # specific warnings are silenced while the scores are computed.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                "Mean of empty slice|invalid value encountered in scalar divide",
                RuntimeWarning,
                "numpy",
            )

            conv_res.confidence.pages[page.page_no].layout_score = float(
                np.mean([c.confidence for c in processed_clusters])
            )

            conv_res.confidence.pages[page.page_no].ocr_score = float(
                np.mean([c.confidence for c in processed_cells if c.from_ocr])
            )

        page.predictions.layout = LayoutPrediction(clusters=processed_clusters)

        if settings.debug.visualize_layout:
            self.draw_clusters_and_cells_side_by_side(
                conv_res, page, processed_clusters, mode_prefix="postprocessed"
            )

        yield page
0 commit comments