Commit 93ce0ba

feat: Add predict_batch to layout predictor (#125)

Signed-off-by: Christoph Auer <[email protected]>

1 parent 0d37642

1 file changed: +88 −0 lines

docling_ibm_models/layoutmodel/layout_predictor.py (88 additions, 0 deletions)
@@ -180,3 +180,91 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
                 "label": label_str,
                 "confidence": score,
             }
+
+    @torch.inference_mode()
+    def predict_batch(
+        self, images: List[Union[Image.Image, np.ndarray]]
+    ) -> List[List[dict]]:
+        """
+        Batch prediction for multiple images - more efficient than calling predict() multiple times.
+
+        Parameters
+        ----------
+        images : List[Union[Image.Image, np.ndarray]]
+            List of images to process in a single batch
+
+        Returns
+        -------
+        List[List[dict]]
+            List of prediction lists, one per input image. Each prediction dict contains:
+            "label", "confidence", "l", "t", "r", "b"
+        """
+        if not images:
+            return []
+
+        # Convert all images to RGB PIL format
+        pil_images = []
+        for img in images:
+            if isinstance(img, Image.Image):
+                pil_images.append(img.convert("RGB"))
+            elif isinstance(img, np.ndarray):
+                pil_images.append(Image.fromarray(img).convert("RGB"))
+            else:
+                raise TypeError("Unsupported input image format")
+
+        # Get target sizes (height, width) for all images
+        target_sizes = torch.tensor([img.size[::-1] for img in pil_images])
+
+        # Process all images in a single batch
+        inputs = self._image_processor(images=pil_images, return_tensors="pt").to(
+            self._device
+        )
+        outputs = self._model(**inputs)
+
+        # Post-process all results at once
+        results_list: List[Dict[str, Tensor]] = (
+            self._image_processor.post_process_object_detection(
+                outputs,
+                target_sizes=target_sizes,
+                threshold=self._threshold,
+            )
+        )
+
+        # Convert results to standard format for each image
+        all_predictions = []
+
+        for img, results in zip(pil_images, results_list):
+            w, h = img.size
+            predictions = []
+
+            for score, label_id, box in zip(
+                results["scores"], results["labels"], results["boxes"]
+            ):
+                score = float(score.item())
+                label_id = int(label_id.item()) + self._label_offset
+                label_str = self._classes_map[label_id]
+
+                # Filter out blacklisted classes
+                if label_str in self._black_classes:
+                    continue
+
+                # Clamp box coordinates to the image bounds
+                bbox_float = [float(b.item()) for b in box]
+                l = min(w, max(0, bbox_float[0]))
+                t = min(h, max(0, bbox_float[1]))
+                r = min(w, max(0, bbox_float[2]))
+                b = min(h, max(0, bbox_float[3]))
+
+                predictions.append(
+                    {
+                        "l": l,
+                        "t": t,
+                        "r": r,
+                        "b": b,
+                        "label": label_str,
+                        "confidence": score,
+                    }
+                )
+
+            all_predictions.append(predictions)
+
+        return all_predictions
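
For context, a minimal usage sketch of the new method follows. The constructor arguments shown (artifact_path, device) are assumptions for illustration and are not part of this commit; only predict_batch itself is introduced here.

    from PIL import Image

    from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

    # Hypothetical setup: the constructor arguments below are assumptions
    # for illustration; artifact_path should point at a downloaded layout model.
    predictor = LayoutPredictor(artifact_path="./model_artifacts", device="cpu")

    # predict_batch accepts PIL images or numpy arrays.
    pages = [Image.open("page_1.png"), Image.open("page_2.png")]

    # One batched forward pass instead of one model call per page.
    for page_idx, predictions in enumerate(predictor.predict_batch(pages)):
        for pred in predictions:
            print(
                page_idx,
                pred["label"],
                round(pred["confidence"], 3),
                (pred["l"], pred["t"], pred["r"], pred["b"]),
            )

Each inner list lines up with the input image at the same index, so an empty list simply means no detections above the threshold for that page.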
