Skip to content

Commit 9de1985

Browse files
committed
debug: test grid sampler monkeypatch
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
1 parent 52bebae commit 9de1985

File tree

1 file changed

+116
-1
lines changed

1 file changed

+116
-1
lines changed

docling/models/stages/ocr/nemotron_ocr_model.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import threading
55
from collections.abc import Iterable, Sequence
66
from pathlib import Path
7-
from typing import Optional, Type, TypedDict, cast
7+
from typing import Any, Optional, Type, TypedDict, cast
88

99
import numpy
1010
from docling_core.types.doc import BoundingBox, CoordOrigin
@@ -25,6 +25,118 @@
2525
_log = logging.getLogger(__name__)
2626

2727

28+
class _GridSamplerDebugWrapper:
29+
def __init__(self, original_sampler: Any):
30+
self._original_sampler = original_sampler
31+
self._call_seq = 0
32+
self._lock = threading.Lock()
33+
34+
@staticmethod
35+
def _branch_name(grid: Any) -> str:
36+
shape = getattr(grid, "shape", None)
37+
if shape is None or len(shape) < 2:
38+
return "unknown"
39+
tail = tuple(int(dim) for dim in shape[-2:])
40+
if tail == (8, 32):
41+
return "recognizer"
42+
if tail == (2, 3):
43+
return "relational"
44+
return "unknown"
45+
46+
@staticmethod
47+
def _describe_tensor(name: str, tensor: Any) -> str:
48+
shape = getattr(tensor, "shape", None)
49+
dtype = getattr(tensor, "dtype", None)
50+
device = getattr(tensor, "device", None)
51+
52+
is_contiguous: Any
53+
try:
54+
is_contiguous = tensor.is_contiguous()
55+
except Exception as exc: # pragma: no cover - debug path
56+
is_contiguous = f"err:{exc}"
57+
58+
is_meta = getattr(tensor, "is_meta", "n/a")
59+
60+
stride: Any
61+
try:
62+
stride = tuple(int(v) for v in tensor.stride())
63+
except Exception as exc: # pragma: no cover - debug path
64+
stride = f"err:{exc}"
65+
66+
storage_offset: Any
67+
try:
68+
storage_offset = tensor.storage_offset()
69+
except Exception as exc: # pragma: no cover - debug path
70+
storage_offset = f"err:{exc}"
71+
72+
data_ptr: Any
73+
try:
74+
data_ptr = tensor.data_ptr()
75+
except Exception as exc: # pragma: no cover - debug path
76+
data_ptr = f"err:{exc}"
77+
78+
return (
79+
f"{name}: type={type(tensor)} shape={shape} dtype={dtype} device={device} "
80+
f"contiguous={is_contiguous} is_meta={is_meta} stride={stride} "
81+
f"storage_offset={storage_offset} data_ptr={data_ptr}"
82+
)
83+
84+
def __call__(self, input_tensor: Any, grid: Any, input_indices: Any) -> Any:
85+
with self._lock:
86+
self._call_seq += 1
87+
call_id = self._call_seq
88+
89+
branch_name = self._branch_name(grid)
90+
print(
91+
f"[nemotron-debug] grid-sampler-enter call={call_id} branch={branch_name}"
92+
)
93+
print(f"[nemotron-debug] {self._describe_tensor('input', input_tensor)}")
94+
print(f"[nemotron-debug] {self._describe_tensor('grid', grid)}")
95+
print(
96+
f"[nemotron-debug] {self._describe_tensor('input_indices', input_indices)}"
97+
)
98+
99+
try:
100+
result = self._original_sampler(input_tensor, grid, input_indices)
101+
print(
102+
f"[nemotron-debug] grid-sampler-ok call={call_id} branch={branch_name}"
103+
)
104+
return result
105+
except RuntimeError as exc:
106+
print(
107+
f"[nemotron-debug] grid-sampler-failed call={call_id} "
108+
f"branch={branch_name} error={exc}"
109+
)
110+
111+
cloned_input = input_tensor.contiguous().clone()
112+
cloned_grid = grid.contiguous().clone()
113+
cloned_input_indices = input_indices.contiguous().clone()
114+
115+
print(
116+
f"[nemotron-debug] grid-sampler-retry call={call_id} "
117+
f"branch={branch_name} mode=contiguous_clone"
118+
)
119+
print(
120+
f"[nemotron-debug] {self._describe_tensor('cloned_input', cloned_input)}"
121+
)
122+
print(
123+
f"[nemotron-debug] {self._describe_tensor('cloned_grid', cloned_grid)}"
124+
)
125+
print(
126+
f"[nemotron-debug] "
127+
f"{self._describe_tensor('cloned_input_indices', cloned_input_indices)}"
128+
)
129+
130+
result = self._original_sampler(
131+
cloned_input, cloned_grid, cloned_input_indices
132+
)
133+
print(
134+
f"[nemotron-debug] grid-sampler-retry-ok call={call_id} "
135+
f"branch={branch_name}"
136+
)
137+
return result
138+
139+
28140
class NemotronOcrPrediction(TypedDict):
29141
"""Exact prediction schema returned by `nemotron_ocr`."""
30142

@@ -71,6 +183,9 @@ def __init__(
71183
else None
72184
)
73185
self.reader = NemotronOCR(model_dir=model_dir)
186+
self.reader.grid_sampler = _GridSamplerDebugWrapper(
187+
self.reader.grid_sampler
188+
)
74189
self._reader_debug_lock = threading.Lock()
75190
self._active_reader_calls = 0
76191
self._reader_call_seq = 0

0 commit comments

Comments
 (0)