|
1 | 1 | import copy |
| 2 | +import operator |
2 | 3 | import warnings |
3 | 4 | from pathlib import Path |
4 | 5 | from typing import Iterable, Optional, Tuple, cast |
@@ -189,144 +190,137 @@ def __call__( |
189 | 190 | page.predictions.tablestructure = ( |
190 | 191 | TableStructurePrediction() |
191 | 192 | ) # dummy |
192 | | - cells_orientation = detect_orientation(page.cells) |
193 | 193 | # Keep only table bboxes |
194 | 194 | in_tables_clusters = [ |
195 | 195 | cluster |
196 | 196 | for cluster in page.predictions.layout.clusters |
197 | 197 | if cluster.label in self._table_labels |
198 | 198 | ] |
199 | 199 |
|
200 | | - if not len(in_tables_clusters): |
| 200 | + if not in_tables_clusters: |
201 | 201 | yield page |
202 | 202 | continue |
203 | 203 | # Rotate and scale table image |
204 | 204 | page_im = cast(Image, page.get_image()) |
205 | | - original_scaled_page_size = ( |
206 | | - int(page_im.size[0] * self.scale), |
207 | | - int(page_im.size[1] * self.scale), |
208 | | - ) |
209 | | - scaled_page_im: Image = cast( |
| 205 | + original_scaled_im: Image = cast( |
210 | 206 | Image, page.get_image(scale=self.scale) |
211 | 207 | ) |
212 | | - if cells_orientation: |
213 | | - scaled_page_im = scaled_page_im.rotate( |
214 | | - -cells_orientation, expand=True |
215 | | - ) |
216 | | - page_input = { |
217 | | - "width": scaled_page_im.size[0], |
218 | | - "height": scaled_page_im.size[1], |
219 | | - "image": numpy.asarray(scaled_page_im), |
220 | | - } |
221 | | - # Rotate and scale table cells |
222 | | - in_tables = [ |
223 | | - ( |
224 | | - c, |
225 | | - [ |
226 | | - round(x) * self.scale |
227 | | - for x in _rotate_bbox( |
228 | | - c.bbox, |
229 | | - orientation=-cells_orientation, |
230 | | - im_size=page_im.size, |
231 | | - ) |
232 | | - .to_top_left_origin(page_im.size[1]) |
233 | | - .as_tuple() |
234 | | - ], |
235 | | - ) |
236 | | - for c in in_tables_clusters |
237 | | - ] |
238 | | - table_clusters, table_bboxes = zip(*in_tables) |
239 | | - |
240 | | - if len(table_bboxes): |
241 | | - for table_cluster, tbl_box in in_tables: |
242 | | - # Check if word-level cells are available from backend: |
243 | | - sp = page._backend.get_segmented_page() |
244 | | - if sp is not None: |
245 | | - tcells = sp.get_cells_in_bbox( |
246 | | - cell_unit=TextCellUnit.WORD, |
247 | | - bbox=table_cluster.bbox, |
248 | | - ) |
249 | | - if len(tcells) == 0: |
250 | | - # In case word-level cells yield empty |
251 | | - tcells = table_cluster.cells |
252 | | - else: |
253 | | - # Otherwise - we use normal (line/phrase) cells |
254 | | - tcells = table_cluster.cells |
255 | | - tokens = [] |
256 | | - for c in tcells: |
257 | | - # Only allow non empty strings (spaces) into the cells of a table |
258 | | - if len(c.text.strip()) > 0: |
259 | | - new_cell = copy.deepcopy(c) |
260 | | - new_cell.rect = BoundingRectangle.from_bounding_box( |
261 | | - new_cell.rect.to_bounding_box().scaled( |
262 | | - scale=self.scale |
263 | | - ) |
264 | | - ) |
265 | | - # _rotate_bbox expects the size of the image in |
266 | | - # which the bbox was found |
267 | | - new_bbox = _rotate_bbox( |
268 | | - new_cell.to_bounding_box(), |
269 | | - orientation=-cells_orientation, |
270 | | - im_size=original_scaled_page_size, |
271 | | - ).model_dump() |
272 | | - tokens.append( |
273 | | - { |
274 | | - "id": new_cell.index, |
275 | | - "text": new_cell.text, |
276 | | - "bbox": new_bbox, |
277 | | - } |
278 | | - ) |
279 | | - page_input["tokens"] = tokens |
| 208 | + original_scaled_page_size = original_scaled_im.size |
| 209 | + clusters_with_orientations = sorted( |
| 210 | + ((c, detect_orientation(c.cells)) for c in in_tables_clusters), |
| 211 | + key=operator.itemgetter(1), |
| 212 | + ) |
280 | 213 |
|
281 | | - tf_output = self.tf_predictor.multi_table_predict( |
282 | | - page_input, [tbl_box], do_matching=self.do_cell_matching |
| 214 | + previous_orientation = None |
| 215 | + for table_cluster, table_orientation in clusters_with_orientations: |
| 216 | + # Rotate the image if needed |
| 217 | + if previous_orientation != table_orientation: |
| 218 | + scaled_page_im = original_scaled_im |
| 219 | + if table_orientation: |
| 220 | + scaled_page_im = original_scaled_im.rotate( |
| 221 | + -table_orientation, expand=True |
| 222 | + ) |
| 223 | + page_input = { |
| 224 | + "width": scaled_page_im.size[0], |
| 225 | + "height": scaled_page_im.size[1], |
| 226 | + "image": numpy.asarray(scaled_page_im), |
| 227 | + } |
| 228 | + previous_orientation = table_orientation |
| 229 | + # Rotate and scale the table bbox |
| 230 | + tbl_box = [ |
| 231 | + round(x) * self.scale |
| 232 | + for x in _rotate_bbox( |
| 233 | + table_cluster.bbox, |
| 234 | + orientation=-table_orientation, |
| 235 | + im_size=page_im.size, |
283 | 236 | ) |
284 | | - table_out = tf_output[0] |
285 | | - table_cells = [] |
286 | | - for element in table_out["tf_responses"]: |
287 | | - if not self.do_cell_matching: |
288 | | - the_bbox = BoundingBox.model_validate( |
289 | | - element["bbox"] |
290 | | - ).scaled(1 / self.scale) |
291 | | - text_piece = page._backend.get_text_in_rect( |
292 | | - the_bbox |
| 237 | + .to_top_left_origin(page_im.size[1]) |
| 238 | + .as_tuple() |
| 239 | + ] |
| 240 | + # Check if word-level cells are available from backend: |
| 241 | + sp = page._backend.get_segmented_page() |
| 242 | + if sp is not None: |
| 243 | + tcells = sp.get_cells_in_bbox( |
| 244 | + cell_unit=TextCellUnit.WORD, |
| 245 | + bbox=table_cluster.bbox, |
| 246 | + ) |
| 247 | + if not tcells: |
| 248 | + # In case word-level cells yield empty |
| 249 | + tcells = table_cluster.cells |
| 250 | + else: |
| 251 | + # Otherwise - we use normal (line/phrase) cells |
| 252 | + tcells = table_cluster.cells |
| 253 | + tokens = [] |
| 254 | + for c in tcells: |
| 255 | + # Only allow non empty strings (spaces) into the cells of a table |
| 256 | + if c.text.strip(): |
| 257 | + new_cell = copy.deepcopy(c) |
| 258 | + new_cell.rect = BoundingRectangle.from_bounding_box( |
| 259 | + new_cell.rect.to_bounding_box().scaled( |
| 260 | + scale=self.scale |
293 | 261 | ) |
294 | | - element["bbox"]["token"] = text_piece |
295 | | - element["bbox"] = _rotate_bbox( |
296 | | - BoundingBox.model_validate(element["bbox"]), |
297 | | - orientation=cells_orientation, |
298 | | - im_size=scaled_page_im.size, |
| 262 | + ) |
| 263 | + # _rotate_bbox expects the size of the image in |
| 264 | + # which the bbox was found |
| 265 | + new_bbox = _rotate_bbox( |
| 266 | + new_cell.to_bounding_box(), |
| 267 | + orientation=-table_orientation, |
| 268 | + im_size=original_scaled_page_size, |
299 | 269 | ).model_dump() |
300 | | - tc = TableCell.model_validate(element) |
301 | | - if tc.bbox is not None: |
302 | | - tc.bbox = tc.bbox.scaled(1 / self.scale) |
303 | | - table_cells.append(tc) |
304 | | - |
305 | | - assert "predict_details" in table_out |
306 | | - |
307 | | - # Retrieving cols/rows, after post processing: |
308 | | - num_rows = table_out["predict_details"].get("num_rows", 0) |
309 | | - num_cols = table_out["predict_details"].get("num_cols", 0) |
310 | | - otsl_seq = ( |
311 | | - table_out["predict_details"] |
312 | | - .get("prediction", {}) |
313 | | - .get("rs_seq", []) |
314 | | - ) |
| 270 | + tokens.append( |
| 271 | + { |
| 272 | + "id": new_cell.index, |
| 273 | + "text": new_cell.text, |
| 274 | + "bbox": new_bbox, |
| 275 | + } |
| 276 | + ) |
| 277 | + page_input["tokens"] = tokens |
315 | 278 |
|
316 | | - tbl = Table( |
317 | | - otsl_seq=otsl_seq, |
318 | | - table_cells=table_cells, |
319 | | - num_rows=num_rows, |
320 | | - num_cols=num_cols, |
321 | | - id=table_cluster.id, |
322 | | - page_no=page.page_no, |
323 | | - cluster=table_cluster, |
324 | | - label=table_cluster.label, |
325 | | - ) |
| 279 | + tf_output = self.tf_predictor.multi_table_predict( |
| 280 | + page_input, [tbl_box], do_matching=self.do_cell_matching |
| 281 | + ) |
| 282 | + table_out = tf_output[0] |
| 283 | + table_cells = [] |
| 284 | + for element in table_out["tf_responses"]: |
| 285 | + if not self.do_cell_matching: |
| 286 | + the_bbox = BoundingBox.model_validate( |
| 287 | + element["bbox"] |
| 288 | + ).scaled(1 / self.scale) |
| 289 | + text_piece = page._backend.get_text_in_rect(the_bbox) |
| 290 | + element["bbox"]["token"] = text_piece |
| 291 | + element["bbox"] = _rotate_bbox( |
| 292 | + BoundingBox.model_validate(element["bbox"]), |
| 293 | + orientation=table_orientation, |
| 294 | + im_size=scaled_page_im.size, |
| 295 | + ).model_dump() |
| 296 | + tc = TableCell.model_validate(element) |
| 297 | + if tc.bbox is not None: |
| 298 | + tc.bbox = tc.bbox.scaled(1 / self.scale) |
| 299 | + table_cells.append(tc) |
| 300 | + |
| 301 | + assert "predict_details" in table_out |
| 302 | + |
| 303 | + # Retrieving cols/rows, after post processing: |
| 304 | + num_rows = table_out["predict_details"].get("num_rows", 0) |
| 305 | + num_cols = table_out["predict_details"].get("num_cols", 0) |
| 306 | + otsl_seq = ( |
| 307 | + table_out["predict_details"] |
| 308 | + .get("prediction", {}) |
| 309 | + .get("rs_seq", []) |
| 310 | + ) |
| 311 | + |
| 312 | + tbl = Table( |
| 313 | + otsl_seq=otsl_seq, |
| 314 | + table_cells=table_cells, |
| 315 | + num_rows=num_rows, |
| 316 | + num_cols=num_cols, |
| 317 | + id=table_cluster.id, |
| 318 | + page_no=page.page_no, |
| 319 | + cluster=table_cluster, |
| 320 | + label=table_cluster.label, |
| 321 | + ) |
326 | 322 |
|
327 | | - page.predictions.tablestructure.table_map[ |
328 | | - table_cluster.id |
329 | | - ] = tbl |
| 323 | + page.predictions.tablestructure.table_map[table_cluster.id] = tbl |
330 | 324 |
|
331 | 325 | # For debugging purposes: |
332 | 326 | if settings.debug.visualize_tables: |
|
0 commit comments