Skip to content

Commit a7cd139

Browse files
authored
(retriever) fix remote ocr and pe logic to match local behavior (#1810)
1 parent f3a489e commit a7cd139

File tree

3 files changed: +23 -4 lines changed

nemo_retriever/src/nemo_retriever/nim/nim.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def invoke_image_inference_batches(
161161
*,
162162
invoke_url: str,
163163
image_b64_list: Sequence[str],
164+
merge_levels: Optional[Sequence[str]] = None,
164165
api_key: Optional[str] = None,
165166
timeout_s: float = 60.0,
166167
max_batch_size: int = 8,
@@ -174,6 +175,14 @@ def invoke_image_inference_batches(
174175
`invoke_url` may be a single URL or a comma-separated URL list.
175176
When multiple URLs are provided, batches are distributed round-robin.
176177
178+
Parameters
179+
----------
180+
merge_levels
181+
Optional per-image merge level (``"word"``, ``"sentence"``, or
182+
``"paragraph"``). When provided, must have the same length as
183+
*image_b64_list*. Passed as ``merge_levels`` in the JSON payload
184+
so the NIM can apply per-crop merging behaviour.
185+
177186
Returns one response item per input image, in the same order.
178187
"""
179188
invoke_urls = _parse_invoke_urls(invoke_url)
@@ -187,6 +196,9 @@ def invoke_image_inference_batches(
187196
if n == 0:
188197
return []
189198

199+
if merge_levels is not None and len(merge_levels) != n:
200+
raise ValueError(f"merge_levels length ({len(merge_levels)}) must match image_b64_list length ({n})")
201+
190202
ranges = _chunk_ranges(n, int(max_batch_size))
191203
flattened: List[Optional[Any]] = [None] * n
192204

@@ -198,7 +210,9 @@ def _invoke_one_batch(start: int, end: int, endpoint_url: str) -> Tuple[int, int
198210
}
199211
for b64 in image_b64_list[start:end]
200212
]
201-
payload = {"input": inputs}
213+
payload: Dict[str, Any] = {"input": inputs}
214+
if merge_levels is not None:
215+
payload["merge_levels"] = list(merge_levels[start:end])
202216
response_json = _post_with_retries(
203217
invoke_url=endpoint_url,
204218
payload=payload,

nemo_retriever/src/nemo_retriever/ocr/ocr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,11 +562,16 @@ def ocr_page_elements(
562562
crops = _crop_all_from_page(page_image_b64, dets, row_wanted, as_b64=True)
563563
crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops]
564564
crop_meta: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops]
565+
# Tables need word-level merging; everything else uses paragraph.
566+
crop_merge_levels: List[str] = [
567+
"word" if label == "table" else "paragraph" for label, _bbox, _b64 in crops
568+
]
565569

566570
if crop_b64s:
567571
response_items = invoke_image_inference_batches(
568572
invoke_url=invoke_url,
569573
image_b64_list=crop_b64s,
574+
merge_levels=crop_merge_levels,
570575
api_key=api_key,
571576
timeout_s=float(request_timeout_s),
572577
max_batch_size=int(kwargs.get("inference_batch_size", 8)),

nemo_retriever/src/nemo_retriever/page_elements/page_elements.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def _remote_response_to_detections(
417417
batch_size=1,
418418
label_names=label_names,
419419
)[0]
420-
return _apply_page_elements_v3_postprocess(dets)
420+
return _apply_final_score_filter(_apply_page_elements_v3_postprocess(dets))
421421
except Exception:
422422
pass
423423

@@ -430,7 +430,7 @@ def _remote_response_to_detections(
430430
if isinstance(bb, dict):
431431
try:
432432
dets = _bounding_boxes_to_detections(bb)
433-
return _apply_page_elements_v3_postprocess(dets)
433+
return _apply_final_score_filter(_apply_page_elements_v3_postprocess(dets))
434434
except Exception:
435435
pass
436436

@@ -442,7 +442,7 @@ def _remote_response_to_detections(
442442
if all(isinstance(v, list) for v in cand.values()):
443443
try:
444444
dets = _annotation_dict_to_detections(cand) # type: ignore[arg-type]
445-
return _apply_page_elements_v3_postprocess(dets)
445+
return _apply_final_score_filter(_apply_page_elements_v3_postprocess(dets))
446446
except Exception:
447447
pass
448448

0 commit comments

Comments
 (0)