Commit 38b3329

Merge pull request #97 from WecoAI/line-plot-example

Add extract-line-plot example

2 parents 57a9408 + e8b84e1, commit 38b3329

File tree: 3 files changed, +138 −101 lines

examples/extract-line-plot/README.md

Lines changed: 26 additions & 6 deletions
````diff
@@ -1,6 +1,6 @@
-## Extract Line Plot (Chart → CSV) with a VLM
+## Extract Line Plot (Chart → CSV): Accuracy/Cost Optimization for Agentic Workflow
 
-This example is about optimizing an AI feature that turns image of chart into a table in csv format.
+This example demonstrates optimizing an AI feature that turns chart images into CSV tables, showcasing how to use Weco to improve accuracy or reduce cost of a VLM-based extraction workflow.
 
 ### Prerequisites
 
@@ -15,8 +15,9 @@ export OPENAI_API_KEY=your_key_here
 ### Files
 
 - `prepare_data.py`: downloads ChartQA (full) and prepares a 100-sample subset of line charts.
-- `optimize.py`: baseline VLM function (`VLMExtractor.image_to_csv`) to be optimized.
+- `optimize.py`: exposes `extract_csv(image_path)` which returns CSV text plus the per-call cost (helpers stay private).
 - `eval.py`: evaluation harness that runs the baseline on images and reports a similarity score as "accuracy".
+- `guide.md`: optional additional instructions you can feed to Weco via `--additional-instructions guide.md`.
 
 Generated artifacts (gitignored):
 - `subset_line_100/` and `subset_line_100.zip`
@@ -47,12 +48,21 @@ Metric definition (summarized):
 - Per-sample score = 0.2 × header match + 0.8 × Jaccard(similarity of content rows).
 - Reported `accuracy` is the mean score over all evaluated samples.
 
+To emit a secondary `cost` metric that Weco can minimize (while enforcing `accuracy > 0.45`), append `--cost-metric`:
+
+```bash
+uv run --with openai python eval.py --max-samples 10 --num-workers 4 --cost-metric
+```
+
+If the final accuracy falls at or below `0.45`, the reported cost is replaced with a large penalty so Weco keeps searching for higher-accuracy solutions.
+You can tighten or relax this constraint with `--cost-accuracy-threshold`, e.g. `--cost-accuracy-threshold 0.50`.
+
 ### 3) Optimize the baseline with Weco
 
 Run Weco to iteratively improve `optimize.py` using 100 examples and many workers:
 
 ```bash
-weco run --source optimize.py --eval-command 'uv run --with openai python eval.py --max-samples 100 --num-workers 50' --metric accuracy --goal maximize --steps 20 --model gpt-5
+weco run --source optimize.py --eval-command 'uv run --with openai python eval.py --max-samples 100 --num-workers 50' --metric accuracy --goal maximize --steps 20 --model gpt-5 --additional-instructions guide.md
 ```
 
 Arguments:
@@ -63,10 +73,20 @@ Arguments:
 - `--steps 20`: number of optimization iterations.
 - `--model gpt-5`: model used by Weco to propose edits (change as desired).
 
+To minimize cost instead (subject to the accuracy constraint), enable the flag in the eval command and switch the optimization target:
+
+```bash
+weco run --source optimize.py --eval-command 'uv run --with openai python eval.py --max-samples 100 --num-workers 50 --cost-metric' --metric cost --goal minimize --steps 20 --model gpt-5 --additional-instructions guide.md
+```
+
+#### Cost optimization workflow
+- Run the evaluation command with `--cost-metric` once to confirm accuracy meets your threshold and note the baseline cost.
+- Adjust `--cost-accuracy-threshold` if you want to tighten or relax the constraint before launching optimization.
+- Kick off Weco with `--metric cost --goal minimize --additional-instructions guide.md` so the optimizer respects the constraint while acting on the extra tips.
+
 ### Tips
 
 - Ensure your OpenAI key has access to a vision-capable model (default: `gpt-4o-mini` in the eval; change via `--model`).
 - Adjust `--num-workers` to balance throughput and rate limits.
 - You can tweak baseline behavior in `optimize.py` (prompt, temperature) — Weco will explore modifications automatically during optimization.
-
-
+- Include `--additional-instructions guide.md` whenever you run Weco so those cost-conscious hints influence the generated proposals.
````
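The README's metric definition (per-sample score = 0.2 × header match + 0.8 × Jaccard over content rows) can be sketched as below. This is an illustrative reconstruction, not the repository's exact implementation: the function names `jaccard` and `score_sample`, and the simple line-wise row canonicalization, are assumptions.

```python
def jaccard(a: set, b: set) -> float:
    # Jaccard similarity |A ∩ B| / |A ∪ B|; treat two empty sets as a perfect match
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def score_sample(gt_csv: str, pred_csv: str) -> float:
    # Compare header row exactly, content rows as sets of whole lines
    gt_lines = [ln.strip() for ln in gt_csv.strip().splitlines()]
    pred_lines = [ln.strip() for ln in pred_csv.strip().splitlines()]
    header_match = 1.0 if gt_lines[:1] == pred_lines[:1] else 0.0
    row_sim = jaccard(set(gt_lines[1:]), set(pred_lines[1:]))
    # Weighted combination from the README: 0.2 header + 0.8 row Jaccard
    return 0.2 * header_match + 0.8 * row_sim
```

A perfect prediction scores 1.0; a prediction with the wrong header but one of two rows correct scores 0.8 × (1/3) ≈ 0.27, matching the weighting described above.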

examples/extract-line-plot/eval.py

Lines changed: 47 additions & 11 deletions
```diff
@@ -8,7 +8,7 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-from optimize import VLMExtractor
+from optimize import extract_csv
 
 try:
     import matplotlib
@@ -18,6 +18,9 @@
 except Exception:  # pragma: no cover - optional dependency
     plt = None
 
+COST_ACCURACY_THRESHOLD_DEFAULT = 0.45
+COST_CONSTRAINT_PENALTY = 1_000_000.0
+
 
 def read_index(index_csv_path: Path) -> List[Tuple[str, Path, Path]]:
     rows: List[Tuple[str, Path, Path]] = []
@@ -259,14 +262,14 @@ def evaluate_predictions(gt_csv_path: Path, pred_csv_path: Path) -> float:
 
 
 def process_one(
-    extractor: VLMExtractor, base_dir: Path, example_id: str, image_rel: Path, gt_table_rel: Path, output_dir: Path
-) -> Tuple[str, float, Path, Path]:
+    base_dir: Path, example_id: str, image_rel: Path, gt_table_rel: Path, output_dir: Path
+) -> Tuple[str, float, Path, Path, float]:
     image_path = base_dir / image_rel
     gt_csv_path = base_dir / gt_table_rel
-    pred_csv_text = extractor.image_to_csv(image_path)
+    pred_csv_text, cost_usd = extract_csv(image_path)
     pred_path = write_csv(output_dir, example_id, pred_csv_text)
     score = evaluate_predictions(gt_csv_path, pred_path)
-    return example_id, score, pred_path, gt_csv_path
+    return example_id, score, pred_path, gt_csv_path, cost_usd
 
 
 def main() -> None:
@@ -276,6 +279,20 @@ def main() -> None:
     parser.add_argument("--out-dir", type=str, default="predictions")
     parser.add_argument("--max-samples", type=int, default=100)
     parser.add_argument("--num-workers", type=int, default=4)
+    parser.add_argument(
+        "--cost-metric",
+        action="store_true",
+        help=(
+            "When set, also report a `cost:` metric suitable for Weco minimization. "
+            "Requires final accuracy to exceed --cost-accuracy-threshold; otherwise a large penalty is reported."
+        ),
+    )
+    parser.add_argument(
+        "--cost-accuracy-threshold",
+        type=float,
+        default=COST_ACCURACY_THRESHOLD_DEFAULT,
+        help="Minimum accuracy required when --cost-metric is set (default: 0.45).",
+    )
     parser.add_argument(
         "--visualize-dir",
         type=str,
@@ -307,30 +324,31 @@ def main() -> None:
         sys.exit(1)
 
     rows = read_index(index_path)[: args.max_samples]
-    extractor = VLMExtractor()
 
     visualize_dir: Optional[Path] = Path(args.visualize_dir) if args.visualize_dir else None
    visualize_max = max(0, args.visualize_max)
     if visualize_dir and plt is None:
         print("[warn] matplotlib not available; skipping visualization.", file=sys.stderr)
         visualize_dir = None
 
-    print(f"[setup] evaluating {len(rows)} samples using {extractor.model} …", flush=True)
+    print(f"[setup] evaluating {len(rows)} samples …", flush=True)
     start = time.time()
     scores: List[float] = []
+    costs: List[float] = []
     saved_visualizations = 0
 
     with ThreadPoolExecutor(max_workers=max(1, args.num_workers)) as pool:
         futures = [
-            pool.submit(process_one, extractor, base_dir, example_id, image_rel, gt_table_rel, Path(args.out_dir))
+            pool.submit(process_one, base_dir, example_id, image_rel, gt_table_rel, Path(args.out_dir))
             for (example_id, image_rel, gt_table_rel) in rows
         ]
 
         try:
             for idx, fut in enumerate(as_completed(futures), 1):
                 try:
-                    example_id, score, pred_path, gt_csv_path = fut.result()
+                    example_id, score, pred_path, gt_csv_path, cost_usd = fut.result()
                     scores.append(score)
+                    costs.append(cost_usd)
                     if visualize_dir and (visualize_max == 0 or saved_visualizations < visualize_max):
                         out_path = visualize_difference(
                             gt_csv_path,
@@ -346,7 +364,11 @@ def main() -> None:
                     if idx % 5 == 0 or idx == len(rows):
                         elapsed = time.time() - start
                         avg = sum(scores) / len(scores) if scores else 0.0
-                        print(f"[progress] {idx}/{len(rows)} done, avg score: {avg:.4f}, elapsed {elapsed:.1f}s", flush=True)
+                        avg_cost = sum(costs) / len(costs) if costs else 0.0
+                        print(
+                            f"[progress] {idx}/{len(rows)} done, avg score: {avg:.4f}, avg cost: ${avg_cost:.4f}, elapsed {elapsed:.1f}s",
+                            flush=True,
+                        )
                 except Exception as e:
                     print(f"[error] failed on sample {idx}: {e}", file=sys.stderr)
         except KeyboardInterrupt:
@@ -356,7 +378,7 @@ def main() -> None:
     final_score = sum(scores) / len(scores) if scores else 0.0
 
     # Apply cost cap: accuracy is zeroed if average cost/query exceeds $0.02
-    avg_cost_per_query = (extractor.total_cost_usd / extractor.num_queries) if getattr(extractor, "num_queries", 0) else 0.0
+    avg_cost_per_query = (sum(costs) / len(costs)) if costs else 0.0
     if avg_cost_per_query > 0.02:
         print(f"[cost] avg ${avg_cost_per_query:.4f}/query exceeds $0.02 cap; accuracy set to 0.0", flush=True)
         final_score = 0.0
@@ -365,6 +387,20 @@ def main() -> None:
 
     print(f"accuracy: {final_score:.4f}")
 
+    if args.cost_metric:
+        if final_score > args.cost_accuracy_threshold:
+            reported_cost = avg_cost_per_query
+        else:
+            print(
+                (
+                    f"[constraint] accuracy {final_score:.4f} <= "
+                    f"threshold {args.cost_accuracy_threshold:.2f}; reporting penalty ${COST_CONSTRAINT_PENALTY:.1f}"
+                ),
+                flush=True,
+            )
+            reported_cost = COST_CONSTRAINT_PENALTY
+        print(f"cost: {reported_cost:.6f}")
+
 
 if __name__ == "__main__":
     main()
```
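The constraint handling added to `eval.py` above follows a common pattern for constrained single-metric optimization: report the true cost only when accuracy clears the threshold, and otherwise report a huge penalty so a minimizer steers away. A minimal standalone sketch of that gating logic (the function name `reported_cost` is illustrative, not part of the repo):

```python
PENALTY = 1_000_000.0  # large sentinel value, as in the diff's COST_CONSTRAINT_PENALTY


def reported_cost(accuracy: float, avg_cost: float, threshold: float = 0.45) -> float:
    # Emit the real average cost only when the accuracy constraint is satisfied;
    # a strict ">" means accuracy exactly at the threshold still gets the penalty.
    return avg_cost if accuracy > threshold else PENALTY
```

Because Weco only sees the printed `cost:` line, this single scalar encodes both objective and constraint: feasible solutions compete on real cost, infeasible ones are uniformly terrible.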
examples/extract-line-plot/optimize.py

Lines changed: 65 additions & 84 deletions
```diff
@@ -1,19 +1,23 @@
 """
 optimize.py
 
-Baseline implementation of a VLM-driven function that takes an image and returns CSV.
-Weco will optimize the prompt and logic here.
+Exposes a single public entry point `extract_csv` that turns a chart image into CSV text.
+All helper utilities remain private to this module.
 """
 
 import base64
-import threading
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
 
 from openai import OpenAI
 
+__all__ = ["extract_csv"]
 
-def build_prompt() -> str:
+_DEFAULT_MODEL = "gpt-4o-mini"
+_CLIENT = OpenAI()
+
+
+def _build_prompt() -> str:
     return (
         "You are a precise data extraction model. Given a chart image, extract the underlying data table.\n"
         "Return ONLY the CSV text with a header row and no markdown code fences.\n"
@@ -25,92 +29,69 @@ def build_prompt() -> str:
     )
 
 
-def image_to_data_uri(image_path: Path) -> str:
+def _image_to_data_uri(image_path: Path) -> str:
     mime = "image/png" if image_path.suffix.lower() == ".png" else "image/jpeg"
     data = image_path.read_bytes()
     b64 = base64.b64encode(data).decode("ascii")
     return f"data:{mime};base64,{b64}"
 
 
-def clean_to_csv(text: str) -> str:
+def _clean_to_csv(text: str) -> str:
     return text.strip()
 
 
-class VLMExtractor:
-    """Baseline VLM wrapper for chart-to-CSV extraction."""
-
-    def __init__(self, model: str = "gpt-4o-mini", client: Optional[OpenAI] = None) -> None:
-        self.model = model
-        self.client = client or OpenAI()
-        # Aggregates
-        self.total_prompt_tokens: int = 0
-        self.total_completion_tokens: int = 0
-        self.total_cost_usd: float = 0.0
-        self.num_queries: int = 0
-        self._usage_lock = threading.Lock()
-
-    def _pricing_for_model(self) -> dict:
-        """Return pricing for current model in USD per token.
-
-        Structure: {"in": x, "in_cached": y, "out": z}
-        Defaults to GPT-5 mini if model not matched.
-        """
-        name = (self.model or "").lower()
-        # Prices are given per 1M tokens in the spec; convert to per-token
-        per_million = {
-            "gpt-5": {"in": 1.250, "in_cached": 0.125, "out": 10.000},
-            "gpt-5-mini": {"in": 0.250, "in_cached": 0.025, "out": 2.000},
-            "gpt-5-nano": {"in": 0.050, "in_cached": 0.005, "out": 0.400},
-        }
-        # Pick by prefix
-        if name.startswith("gpt-5-nano"):
-            chosen = per_million["gpt-5-nano"]
-        elif name.startswith("gpt-5-mini"):
-            chosen = per_million["gpt-5-mini"]
-        elif name.startswith("gpt-5"):
-            chosen = per_million["gpt-5"]
-        else:
-            chosen = per_million["gpt-5-mini"]
-        # Convert per 1M to per token
-        return {k: v / 1_000_000.0 for k, v in chosen.items()}
-
-    def image_to_csv(self, image_path: Path) -> str:
-        prompt = build_prompt()
-        image_uri = image_to_data_uri(image_path)
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_uri}}],
-                }
-            ],
+def _pricing_for_model(model_name: str) -> dict:
+    """Return pricing information for the given model in USD per token."""
+    name = (model_name or "").lower()
+    per_million = {
+        "gpt-5": {"in": 1.250, "in_cached": 0.125, "out": 10.000},
+        "gpt-5-mini": {"in": 0.250, "in_cached": 0.025, "out": 2.000},
+        "gpt-5-nano": {"in": 0.050, "in_cached": 0.005, "out": 0.400},
+    }
+    if name.startswith("gpt-5-nano"):
+        chosen = per_million["gpt-5-nano"]
+    elif name.startswith("gpt-5-mini"):
+        chosen = per_million["gpt-5-mini"]
+    elif name.startswith("gpt-5"):
+        chosen = per_million["gpt-5"]
+    else:
+        chosen = per_million["gpt-5-mini"]
+    return {k: v / 1_000_000.0 for k, v in chosen.items()}
+
+
+def extract_csv(image_path: Path, model: Optional[str] = None) -> Tuple[str, float]:
+    """
+    Extract CSV text from an image and return (csv_text, cost_usd).
+
+    The caller can optionally override the model name; otherwise the default is used.
+    """
+    effective_model = model or _DEFAULT_MODEL
+    prompt = _build_prompt()
+    image_uri = _image_to_data_uri(image_path)
+    response = _CLIENT.chat.completions.create(
+        model=effective_model,
+        messages=[
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_uri}}],
+            }
+        ],
+    )
+
+    usage = getattr(response, "usage", None)
+    cost_usd = 0.0
+    if usage is not None:
+        prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
+        completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
+        details = getattr(usage, "prompt_tokens_details", None)
+        cached_tokens = 0
+        if details is not None:
+            cached_tokens = int(getattr(details, "cached_tokens", 0) or 0)
+        non_cached_prompt_tokens = max(0, prompt_tokens - cached_tokens)
+        rates = _pricing_for_model(effective_model)
+        cost_usd = (
+            non_cached_prompt_tokens * rates["in"] + cached_tokens * rates["in_cached"] + completion_tokens * rates["out"]
         )
-        # Track usage and cost if available
-        usage = getattr(response, "usage", None)
-        with self._usage_lock:
-            if usage is not None:
-                prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
-                completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
-                # Attempt to detect cached tokens if available
-                details = getattr(usage, "prompt_tokens_details", None)
-                cached_tokens = 0
-                if details is not None:
-                    cached_tokens = int(getattr(details, "cached_tokens", 0) or 0)
-                non_cached_prompt_tokens = max(0, prompt_tokens - cached_tokens)
-
-                rates = self._pricing_for_model()
-                cost = (
-                    non_cached_prompt_tokens * rates["in"]
-                    + cached_tokens * rates["in_cached"]
-                    + completion_tokens * rates["out"]
-                )
-
-                self.total_prompt_tokens += prompt_tokens
-                self.total_completion_tokens += completion_tokens
-                self.total_cost_usd += cost
-                self.num_queries += 1
-            else:
-                self.num_queries += 1
-        text = response.choices[0].message.content or ""
-        return clean_to_csv(text)
+
+    text = response.choices[0].message.content or ""
+    return _clean_to_csv(text), cost_usd
```
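The per-call cost computed in `extract_csv` above is a weighted sum of non-cached prompt tokens, cached prompt tokens, and completion tokens, using per-1M-token prices converted to per-token rates. A self-contained sketch of that arithmetic (using the gpt-5-mini price row from the diff; the names `RATES` and `call_cost` are illustrative):

```python
# Per-1M-token prices (gpt-5-mini row from the diff), converted to per-token rates
PER_MILLION = {"in": 0.250, "in_cached": 0.025, "out": 2.000}
RATES = {k: v / 1_000_000.0 for k, v in PER_MILLION.items()}


def call_cost(prompt_tokens: int, cached_tokens: int, completion_tokens: int) -> float:
    # Cached prompt tokens are billed at the discounted "in_cached" rate;
    # only the remainder pays the full input rate.
    non_cached = max(0, prompt_tokens - cached_tokens)
    return non_cached * RATES["in"] + cached_tokens * RATES["in_cached"] + completion_tokens * RATES["out"]
```

For example, 1000 uncached prompt tokens plus 100 completion tokens cost 1000 × 0.25/1M + 100 × 2.0/1M = $0.00045, well under the evaluation's $0.02-per-query cap.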