|
4 | 4 | import threading |
5 | 5 | from collections.abc import Iterable, Sequence |
6 | 6 | from pathlib import Path |
7 | | -from typing import Optional, Type, TypedDict, cast |
| 7 | +from typing import Any, Optional, Type, TypedDict, cast |
8 | 8 |
|
9 | 9 | import numpy |
10 | 10 | from docling_core.types.doc import BoundingBox, CoordOrigin |
|
25 | 25 | _log = logging.getLogger(__name__) |
26 | 26 |
|
27 | 27 |
|
class _GridSamplerDebugWrapper:
    """Callable proxy that traces every call made to a wrapped grid sampler.

    Each call is given a sequence number and is logged to stdout (prefixed
    with ``[nemotron-debug]``) together with a description of the three
    arguments.  When the wrapped sampler raises ``RuntimeError`` the call is
    retried once with contiguous copies of the arguments.

    NOTE(review): the arguments are presumably torch tensors (the code probes
    ``is_contiguous``/``stride``/``storage_offset``/``data_ptr``) -- confirm
    against the sampler's real signature.
    """

    # Maps the last two dimensions of ``grid`` to the label used in the log.
    _BRANCH_BY_GRID_TAIL = {
        (8, 32): "recognizer",
        (2, 3): "relational",
    }

    def __init__(self, original_sampler: Any):
        """Wrap *original_sampler*, a callable taking (input, grid, indices)."""
        self._original_sampler = original_sampler
        self._call_seq = 0
        # Guards ``_call_seq`` so every call gets a unique id.
        self._lock = threading.Lock()

    @staticmethod
    def _branch_name(grid: Any) -> str:
        """Return the log label for *grid* based on its trailing two dims.

        Returns ``"unknown"`` when *grid* has no ``shape``, fewer than two
        dimensions, or a trailing shape that is not in the lookup table.
        """
        shape = getattr(grid, "shape", None)
        if shape is None or len(shape) < 2:
            return "unknown"
        tail = tuple(int(dim) for dim in shape[-2:])
        return _GridSamplerDebugWrapper._BRANCH_BY_GRID_TAIL.get(tail, "unknown")

    @staticmethod
    def _describe_tensor(name: str, tensor: Any) -> str:
        """Return a one-line, best-effort description of *tensor*.

        Never raises for missing attributes or failing accessor methods: an
        absent attribute is reported as ``None`` (``"n/a"`` for ``is_meta``)
        and a failing method call as ``"err:<message>"``.
        """

        def probe(method_name: str, convert: Any = None) -> Any:
            # Debug path: any failure is folded into the description instead
            # of interrupting the real sampler call.
            try:
                value = getattr(tensor, method_name)()
                return value if convert is None else convert(value)
            except Exception as exc:  # pragma: no cover - debug path
                return f"err:{exc}"

        shape = getattr(tensor, "shape", None)
        dtype = getattr(tensor, "dtype", None)
        device = getattr(tensor, "device", None)
        is_contiguous = probe("is_contiguous")
        is_meta = getattr(tensor, "is_meta", "n/a")
        stride = probe("stride", lambda values: tuple(int(v) for v in values))
        storage_offset = probe("storage_offset")
        data_ptr = probe("data_ptr")

        return (
            f"{name}: type={type(tensor)} shape={shape} dtype={dtype} device={device} "
            f"contiguous={is_contiguous} is_meta={is_meta} stride={stride} "
            f"storage_offset={storage_offset} data_ptr={data_ptr}"
        )

    @staticmethod
    def _contiguous_clone(value: Any) -> Any:
        """Return ``value.contiguous().clone()`` when supported, else *value*.

        Objects without a callable ``contiguous`` (e.g. ``None`` or a list)
        are handed back unchanged so the fallback path cannot fail with
        ``AttributeError`` and hide the sampler's own error.
        """
        contiguous = getattr(value, "contiguous", None)
        if not callable(contiguous):
            return value
        compact = contiguous()
        clone = getattr(compact, "clone", None)
        return clone() if callable(clone) else compact

    def __call__(self, input_tensor: Any, grid: Any, input_indices: Any) -> Any:
        """Invoke the wrapped sampler, logging the call and retrying once.

        Returns whatever the wrapped sampler returns.  A ``RuntimeError`` from
        the first attempt triggers one retry with contiguous copies of the
        arguments; if no argument can be copied the original error is
        re-raised.  Errors raised by the retry propagate to the caller.
        """
        with self._lock:
            self._call_seq += 1
            call_id = self._call_seq

        branch_name = self._branch_name(grid)
        print(
            f"[nemotron-debug] grid-sampler-enter call={call_id} branch={branch_name}"
        )
        print(f"[nemotron-debug] {self._describe_tensor('input', input_tensor)}")
        print(f"[nemotron-debug] {self._describe_tensor('grid', grid)}")
        print(
            f"[nemotron-debug] {self._describe_tensor('input_indices', input_indices)}"
        )

        try:
            result = self._original_sampler(input_tensor, grid, input_indices)
        except RuntimeError as exc:
            print(
                f"[nemotron-debug] grid-sampler-failed call={call_id} "
                f"branch={branch_name} error={exc}"
            )

            originals = (input_tensor, grid, input_indices)
            copies = tuple(self._contiguous_clone(value) for value in originals)
            if all(copy is value for copy, value in zip(copies, originals)):
                # Nothing could be copied, so a retry would repeat the exact
                # same call; surface the sampler's own error instead.
                print(
                    f"[nemotron-debug] grid-sampler-retry-skipped call={call_id} "
                    f"branch={branch_name} reason=no_cloneable_args"
                )
                raise
            cloned_input, cloned_grid, cloned_input_indices = copies

            print(
                f"[nemotron-debug] grid-sampler-retry call={call_id} "
                f"branch={branch_name} mode=contiguous_clone"
            )
            print(
                f"[nemotron-debug] {self._describe_tensor('cloned_input', cloned_input)}"
            )
            print(
                f"[nemotron-debug] {self._describe_tensor('cloned_grid', cloned_grid)}"
            )
            print(
                f"[nemotron-debug] "
                f"{self._describe_tensor('cloned_input_indices', cloned_input_indices)}"
            )

            # Retried inside the handler so a second failure stays chained to
            # the first one in the traceback.
            result = self._original_sampler(
                cloned_input, cloned_grid, cloned_input_indices
            )
            print(
                f"[nemotron-debug] grid-sampler-retry-ok call={call_id} "
                f"branch={branch_name}"
            )
            return result

        print(
            f"[nemotron-debug] grid-sampler-ok call={call_id} branch={branch_name}"
        )
        return result
| 138 | + |
| 139 | + |
28 | 140 | class NemotronOcrPrediction(TypedDict): |
29 | 141 | """Exact prediction schema returned by `nemotron_ocr`.""" |
30 | 142 |
|
@@ -71,6 +183,9 @@ def __init__( |
71 | 183 | else None |
72 | 184 | ) |
73 | 185 | self.reader = NemotronOCR(model_dir=model_dir) |
| 186 | + self.reader.grid_sampler = _GridSamplerDebugWrapper( |
| 187 | + self.reader.grid_sampler |
| 188 | + ) |
74 | 189 | self._reader_debug_lock = threading.Lock() |
75 | 190 | self._active_reader_calls = 0 |
76 | 191 | self._reader_call_seq = 0 |
|
0 commit comments