Skip to content

Commit 216cee0

Browse files
fix(LayoutPredictor): Ensure that the predicted bboxes are minmaxed inside the image boundaries (#42)
Update the LayoutPredictor unit test Signed-off-by: Nikos Livathinos <[email protected]>
1 parent 0683efc commit 216cee0

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ runs/*
1717
.DS_Store
1818
viz/
1919

20+
# VSCode
21+
.vscode
22+
2023
# VIM
2124
*.swp
2225
*.swo

docling_ibm_models/layoutmodel/layout_predictor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,15 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
153153

154154
# Check against threshold
155155
if score > self._threshold:
156+
l = min(w, max(0, box[0]))
157+
t = min(h, max(0, box[1]))
158+
r = min(w, max(0, box[2]))
159+
b = min(h, max(0, box[3]))
156160
yield {
157-
"l": box[0],
158-
"t": box[1],
159-
"r": box[2],
160-
"b": box[3],
161+
"l": l,
162+
"t": t,
163+
"r": r,
164+
"b": b,
161165
"label": label,
162166
"confidence": score,
163167
}

tests/test_layout_predictor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,23 @@ def test_layoutpredictor(init: dict):
7878
# Predict on the test image
7979
for img_fn in init["test_imgs"]:
8080
with Image.open(img_fn) as img:
81+
w, h = img.size
8182
# Load images as PIL objects
8283
for i, pred in enumerate(lpredictor.predict(img)):
8384
print("PIL pred: {}".format(pred))
85+
assert pred["l"] >= 0 and pred["l"] <= w
86+
assert pred["t"] >= 0 and pred["t"] <= h
87+
assert pred["r"] >= 0 and pred["r"] <= w
88+
assert pred["b"] >= 0 and pred["b"] <= h
89+
8490
assert i + 1 == init["pred_bboxes"]
8591

8692
# Load images as numpy arrays
8793
np_arr = np.asarray(img)
8894
for i, pred in enumerate(lpredictor.predict(np_arr)):
8995
print("numpy pred: {}".format(pred))
96+
assert pred["l"] >= 0 and pred["l"] <= w
97+
assert pred["t"] >= 0 and pred["t"] <= h
98+
assert pred["r"] >= 0 and pred["r"] <= w
99+
assert pred["b"] >= 0 and pred["b"] <= h
90100
assert i + 1 == init["pred_bboxes"]

0 commit comments

Comments
 (0)