Skip to content

Commit d9efe73

Browse files
Fix warnings in TF (#14)
* fix:test_tf_predictor: Ensure all bboxes are always valid before drawing
  Signed-off-by: Nikos Livathinos <[email protected]>
* fix:TableFormer: Update torch API calls to suppress warnings:
  - torch.load(): Add parameter `weights_only`
  - torchvision.models.resnet18(): Remove parameter `pretrained`
  Signed-off-by: Nikos Livathinos <[email protected]>
---------
Signed-off-by: Nikos Livathinos <[email protected]>
1 parent 90cb234 commit d9efe73

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

docling_ibm_models/tableformer/models/common/base_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def _load_best_checkpoint(self):
257257
self._log().info(
258258
"Loading model checkpoint file: {}".format(checkpoint_file)
259259
)
260-
saved_model = torch.load(checkpoint_file, map_location=self._device)
260+
saved_model = torch.load(
261+
checkpoint_file, map_location=self._device, weights_only=False
262+
)
261263
return saved_model, checkpoint_file
262264
except RuntimeError:
263265
self._log().error("Cannot load file: {}".format(checkpoint_file))

docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, enc_image_size, enc_dim=512):
3030
self.enc_image_size = enc_image_size
3131
self._encoder_dim = enc_dim
3232

33-
resnet = torchvision.models.resnet18(pretrained=False)
33+
resnet = torchvision.models.resnet18()
3434
modules = list(resnet.children())[:-3]
3535

3636
self._resnet = nn.Sequential(*modules)

tests/test_tf_predictor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,24 +525,24 @@ def test_tf_predictor():
525525

526526
xt0 = table_bboxes[t][0]
527527
yt0 = table_bboxes[t][1]
528-
xt1 = table_bboxes[t][2]
529-
yt1 = table_bboxes[t][3]
528+
xt1 = max(xt0, table_bboxes[t][2])
529+
yt1 = max(yt0, table_bboxes[t][3])
530530
img1.rectangle(((xt0, yt0), (xt1, yt1)), outline="pink", width=5)
531531

532532
if viz:
533533
# Visualize original OCR words:
534534
for iocr_word in iocr_page["tokens"]:
535535
xi0 = iocr_word["bbox"]["l"]
536536
yi0 = iocr_word["bbox"]["t"]
537-
xi1 = iocr_word["bbox"]["r"]
538-
yi1 = iocr_word["bbox"]["b"]
537+
xi1 = max(xi0, iocr_word["bbox"]["r"])
538+
yi1 = max(yi0, iocr_word["bbox"]["b"])
539539
img1.rectangle(((xi0, yi0), (xi1, yi1)), outline="gray")
540540
# Visualize original docling_ibm_models.tableformer predictions:
541541
for predicted_bbox in predict_details["prediction_bboxes_page"]:
542542
xp0 = predicted_bbox[0] - 1
543543
yp0 = predicted_bbox[1] - 1
544-
xp1 = predicted_bbox[2] + 1
545-
yp1 = predicted_bbox[3] + 1
544+
xp1 = max(xp0, predicted_bbox[2] + 1)
545+
yp1 = max(yp0, predicted_bbox[3] + 1)
546546
img1.rectangle(((xp0, yp0), (xp1, yp1)), outline="green")
547547

548548
# Check the structure of the list items
@@ -565,14 +565,14 @@ def test_tf_predictor():
565565
for text_cell in response["text_cell_bboxes"]:
566566
xc0 = text_cell["l"]
567567
yc0 = text_cell["t"]
568-
xc1 = text_cell["r"]
569-
yc1 = text_cell["b"]
568+
xc1 = max(xc0, text_cell["r"])
569+
yc1 = max(yc0, text_cell["b"])
570570
img1.rectangle(((xc0, yc0), (xc1, yc1)), outline="red")
571571

572572
x0 = response["bbox"]["l"] - 2
573573
y0 = response["bbox"]["t"] - 2
574-
x1 = response["bbox"]["r"] + 2
575-
y1 = response["bbox"]["b"] + 2
574+
x1 = max(x0, response["bbox"]["r"] + 2)
575+
y1 = max(y0, response["bbox"]["b"] + 2)
576576

577577
if response["column_header"]:
578578
img1.rectangle(

0 commit comments

Comments
 (0)