Skip to content

Commit be1ba5d

Browse files
committed
fix: improve tick-to-OCR alignment, skip false positive detections
1 parent 137f513 commit be1ba5d

File tree

3 files changed

+105
-16
lines changed

3 files changed

+105
-16
lines changed

chart2csv/core/mistral_ocr.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ def process_both_axes(
153153
"type": "text",
154154
"text": """Extract all numbers from these two chart axis images.
155155
156-
Image 1 is the X-axis (horizontal, read left to right).
157-
Image 2 is the Y-axis (vertical, read top to bottom).
156+
Image 1 is the X-axis (horizontal). Read numbers from LEFT to RIGHT.
157+
Image 2 is the Y-axis (vertical). Read numbers from BOTTOM to TOP (as chart axes normally work).
158158
159-
Return JSON format only:
160-
{"x": [list of numbers], "y": [list of numbers]}
159+
Return JSON format only, with numbers in the order you read them:
160+
{"x": [left to right numbers], "y": [bottom to top numbers]}
161161
162-
Example: {"x": [0, 10, 20, 30], "y": [100, 75, 50, 25, 0]}"""
162+
Example for a chart with X: 0,10,20,30 and Y: 0,25,50,75,100:
163+
{"x": [0, 10, 20, 30], "y": [0, 25, 50, 75, 100]}"""
163164
},
164165
{"type": "image_url", "image_url": x_base64},
165166
{"type": "image_url", "image_url": y_base64}

chart2csv/core/ocr.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,45 @@ def _extract_with_mistral(
118118
y_values = backend.process_axis_strip(y_strip) if y_strip.size > 0 else []
119119

120120
# Align X values with detected ticks
121+
# If more ticks are detected than OCR values, the leading (smallest-pixel) extras are dropped
121122
detected_x_ticks = sorted(ticks["x"])
122-
count = min(len(detected_x_ticks), len(x_values))
123-
for i in range(count):
124-
px = detected_x_ticks[i]
125-
val = x_values[i]
126-
ticks_data["x"].append({"pixel": px, "value": val, "text": str(val)})
123+
n_ocr = len(x_values)
124+
n_detected = len(detected_x_ticks)
125+
126+
if n_ocr > 0 and n_detected >= n_ocr:
127+
# Select N of the detected tick positions, where N is the number of OCR values
128+
# This handles cases where detector finds extra false positives
129+
if n_detected == n_ocr:
130+
selected_x_ticks = detected_x_ticks
131+
else:
132+
# No spacing analysis is performed: the surplus count alone decides what is dropped
133+
# The kept ticks form one contiguous run of the sorted list (no stride is applied)
134+
# Take from the end of range (skip early false positives near axis)
135+
skip = n_detected - n_ocr
136+
selected_x_ticks = detected_x_ticks[skip:]
137+
138+
for i, px in enumerate(selected_x_ticks):
139+
val = x_values[i]
140+
ticks_data["x"].append({"pixel": px, "value": val, "text": str(val)})
127141

128142
# Align Y values with detected ticks
129-
detected_y_ticks = sorted(ticks["y"])
130-
count = min(len(detected_y_ticks), len(y_values))
131-
for i in range(count):
132-
py = detected_y_ticks[i]
133-
val = y_values[i]
134-
ticks_data["y"].append({"pixel": py, "value": val, "text": str(val)})
143+
# Y pixels sorted descending (largest = bottom of chart)
144+
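# Assumes OCR values go bottom-to-top: first value is at bottom (largest pixel)
# NOTE(review): y_values comes from process_axis_strip here - confirm it returns bottom-to-top order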
145+
detected_y_ticks = sorted(ticks["y"], reverse=True)
146+
n_ocr_y = len(y_values)
147+
n_detected_y = len(detected_y_ticks)
148+
149+
if n_ocr_y > 0 and n_detected_y >= n_ocr_y:
150+
if n_detected_y == n_ocr_y:
151+
selected_y_ticks = detected_y_ticks
152+
else:
153+
# Skip extra ticks near X-axis (false positives at bottom of chart)
154+
skip = n_detected_y - n_ocr_y
155+
selected_y_ticks = detected_y_ticks[skip:]
156+
157+
for i, py in enumerate(selected_y_ticks):
158+
val = y_values[i]
159+
ticks_data["y"].append({"pixel": py, "value": val, "text": str(val)})
135160

136161
# Calculate confidence
137162
total_ticks = len(ticks["x"]) + len(ticks["y"])

debug_ocr.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
# The API key must be supplied by the caller's environment. It was previously
# hard-coded on this line, which publishes the credential to everyone who can
# read the repository and silently overrides any key the caller had exported.
# NOTE(review): the key that was committed here should be revoked and rotated.
if not os.environ.get("MISTRAL_API_KEY"):
    print("warning: MISTRAL_API_KEY is not set; export it before running this script")
3+
4+
import sys
5+
sys.path.insert(0, "/app")
6+
import cv2
7+
import numpy as np
8+
from chart2csv.core.detection import detect_axes, detect_ticks
9+
from chart2csv.core.ocr import extract_tick_labels
10+
from chart2csv.core.transform import build_transform, apply_transform
11+
12+
# Create a simple test chart
13+
h, w = 400, 600
14+
img = np.ones((h, w), dtype=np.uint8) * 255
15+
16+
# Draw axes
17+
cv2.line(img, (50, 350), (550, 350), 0, 2)
18+
cv2.line(img, (50, 50), (50, 350), 0, 2)
19+
20+
# Draw X labels: 0,1,2,3,4,5
21+
for i, val in enumerate([0, 1, 2, 3, 4, 5]):
22+
x = 50 + i * 100
23+
cv2.line(img, (x, 345), (x, 355), 0, 2)
24+
cv2.putText(img, str(val), (x-5, 375), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 0, 1)
25+
26+
# Draw Y labels: 0,10,20,30,40,50
27+
for i, val in enumerate([0, 10, 20, 30, 40, 50]):
28+
y = 350 - i * 60
29+
cv2.line(img, (45, y), (55, y), 0, 2)
30+
cv2.putText(img, str(val), (10, y+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 0, 1)
31+
32+
axes, conf = detect_axes(img)
33+
x_ax = axes["x"]
34+
y_ax = axes["y"]
35+
print(f"Axes: X at y={x_ax}, Y at x={y_ax}")
36+
37+
ticks, conf = detect_ticks(img, axes)
38+
print(f"X ticks: {sorted(ticks['x'])}")
39+
print(f"Y ticks: {sorted(ticks['y'])}")
40+
41+
ticks_data, ocr_conf = extract_tick_labels(img, axes, use_mistral=True, use_cache=False)
42+
x_ocr = [(t["pixel"], t["value"]) for t in ticks_data["x"]]
43+
y_ocr = [(t["pixel"], t["value"]) for t in ticks_data["y"]]
44+
print(f"OCR X: {x_ocr}")
45+
print(f"OCR Y: {y_ocr}")
46+
47+
if ticks_data["x"] and ticks_data["y"]:
48+
transform, fit_error = build_transform(ticks=ticks_data)
49+
xa = transform["x"]["a"]
50+
xb = transform["x"]["b"]
51+
ya = transform["y"]["a"]
52+
yb = transform["y"]["b"]
53+
print(f"Transform X: a={xa:.6f}, b={xb:.2f}")
54+
print(f"Transform Y: a={ya:.6f}, b={yb:.2f}")
55+
print(f"Fit error: {fit_error:.4f}")
56+
57+
# Test pixel (250, 170) should be approximately (2, 30)
58+
test_px = np.array([[250, 170]])
59+
result = apply_transform(test_px, transform)
60+
rx = result[0,0]
61+
ry = result[0,1]
62+
print(f"Test: pixel (250,170) -> ({rx:.1f}, {ry:.1f})")
63+
print(f"Expected: roughly (2, 30)")

0 commit comments

Comments
 (0)