Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 96 additions & 11 deletions paddlex/inference/models/object_detection/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def _format_output(self, pred: Sequence[Any]) -> List[dict]:
compatible with SOLOv2 output.
- When len(pred) == 3, it is expected to be in the format [boxes, box_nums, masks],
compatible with Instance Segmentation output.
- When len(pred) >= 2 and pred[2] exists as batch_inds, use batch_inds for grouping.

Returns:
List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
Expand All @@ -185,18 +186,102 @@ def _format_output(self, pred: Sequence[Any]) -> List[dict]:
for i in range(len(pred_class_id))
]

if len(pred) == 3:
# Adapt to Instance Segmentation
pred_mask = []
for idx in range(len(pred[1])):
np_boxes_num = pred[1][idx]
box_idx_end = box_idx_start + np_boxes_num
np_boxes = pred[0][box_idx_start:box_idx_end]
pred_box.append(np_boxes)
# Fix for multi-table cell detection issue:
# When multiple table crops are processed in a single batch, post-NMS ordering
# of detections can shift, causing cell assignments to no longer align with
# their original tables. Solution: group by batch_idx instead of relying on
# the order of flattened predictions.
# Reference: https://github.com/PaddlePaddle/PaddleX/issues/17133

# Check if batch_inds is available (for RT-DETR and similar models)
# RT-DETR models may output batch_inds as a separate output from raw inference
batch_inds = None

# First, try to find batch_inds in the raw RT-DETR outputs
# For RT-DETR, batch_inds might be in a separate output
# We need to distinguish it from masks (which are 3D arrays for Instance Segmentation)
# Check all outputs after box_nums for potential batch_inds
for i in range(2, len(pred)):
candidate = pred[i]
# batch_inds should be 1D array of integers matching boxes length
# masks for Instance Segmentation are typically 3D or have different shape
if (
isinstance(candidate, np.ndarray)
and candidate.ndim == 1
and candidate.dtype in (np.int32, np.int64, np.int8, np.int16)
and len(candidate) == len(pred[0])
and len(pred[1]) > 0
and candidate.max() < len(pred[1]) # batch_id should be < batch_size
and candidate.min() >= 0 # batch_id should be >= 0
):
batch_inds = candidate.astype(np.int32)
break # Found batch_inds, stop searching

# If batch_inds is not found in outputs, construct it from box_nums
# This assumes boxes are ordered correctly before any NMS or reordering
if batch_inds is None and len(pred) >= 2 and len(pred[1]) > 0:
box_nums = pred[1]
if isinstance(box_nums, np.ndarray) and box_nums.ndim == 1:
total_boxes = len(pred[0])
expected_total = int(box_nums.sum())
# Only construct batch_inds if the total matches (boxes haven't been reordered)
if total_boxes == expected_total:
batch_inds = np.zeros(total_boxes, dtype=np.int32)
box_idx = 0
for batch_idx, num_boxes in enumerate(box_nums):
num_boxes = int(num_boxes)
if box_idx + num_boxes <= total_boxes:
batch_inds[box_idx : box_idx + num_boxes] = batch_idx
box_idx += num_boxes
else:
# If mismatch, don't use batch_inds
batch_inds = None
break

# Use batch_inds for grouping if available
# This ensures correct grouping even when post-NMS reorders detections
if batch_inds is not None:
unique_batch_ids = np.unique(batch_inds)
pred_box = []
# Find mask output if exists (for Instance Segmentation)
# Masks are typically not 1D integer arrays (unlike batch_inds)
mask_output_idx = None
if len(pred) >= 3:
for i in range(2, len(pred)):
candidate = pred[i]
# If this is not the batch_inds we found, and it looks like masks
if (
isinstance(candidate, np.ndarray)
and not (candidate.ndim == 1 and candidate.dtype in (np.int32, np.int64, np.int8, np.int16) and len(candidate) == len(pred[0]) and candidate.max() < len(pred[1]) and candidate.min() >= 0)
and len(candidate) == len(pred[0])
):
mask_output_idx = i
break

pred_mask = [] if mask_output_idx is not None else None

# Group by batch_idx instead of slicing (fixes post-NMS ordering issue)
for batch_id in unique_batch_ids:
mask = batch_inds == batch_id
np_boxes = pred[0][mask]
pred_box.append(np_boxes)
if pred_mask is not None and mask_output_idx is not None:
np_masks = pred[mask_output_idx][mask]
pred_mask.append(np_masks)
else:
# Fallback to original box_nums slicing method
if len(pred) == 3:
np_masks = pred[2][box_idx_start:box_idx_end]
pred_mask.append(np_masks)
box_idx_start = box_idx_end
# Adapt to Instance Segmentation
pred_mask = []
for idx in range(len(pred[1])):
np_boxes_num = pred[1][idx]
box_idx_end = box_idx_start + np_boxes_num
np_boxes = pred[0][box_idx_start:box_idx_end]
pred_box.append(np_boxes)
if len(pred) == 3:
np_masks = pred[2][box_idx_start:box_idx_end]
pred_mask.append(np_masks)
box_idx_start = box_idx_end

if len(pred) == 3:
return [
Expand Down