
Commit 2d3212c

Stickler construction from JSON Schema
1 parent 95290ed commit 2d3212c

File tree

1 file changed: +310 −0 lines changed

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
#!/usr/bin/env python3
"""
Bulk evaluation script that uses IDP configuration directly.

This script reads the Stickler evaluation configuration from the IDP config file
(sr_FCC_config.json) instead of requiring a separate stickler_config.json.
"""
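# Example invocation (illustrative only: the script filename, paths and output
# directory below are placeholders, not values taken from this commit; the
# flags themselves come from the argparse definitions further down):
#
#   python bulk_evaluate_with_idp_config.py \
#       --results-dir ./inference_results \
#       --csv-path ./ground_truth.csv \
#       --idp-config-path ./sr_FCC_config.json \
#       --output-dir ./evaluation_output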

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, Any

import pandas as pd
import numpy as np

# Add lib path for idp_common imports
sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "lib" / "idp_common_pkg"))

from idp_common.evaluation.stickler_service import SticklerEvaluationService
from idp_common.models import Section

def to_json_serializable(obj):
    """Convert numpy types to Python native types."""
    if isinstance(obj, (np.bool_, np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {k: to_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [to_json_serializable(item) for item in obj]
    return obj

def extract_stickler_config_from_idp_config(idp_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract Stickler configuration from IDP config JSON Schema.

    Args:
        idp_config: Full IDP configuration

    Returns:
        Stickler configuration in the format expected by SticklerEvaluationService
    """
    # Get the first class definition (assuming single document type)
    if "classes" not in idp_config or not idp_config["classes"]:
        raise ValueError("No classes found in IDP configuration")

    class_schema = idp_config["classes"][0]

    # Extract model name and threshold
    model_name = class_schema.get("x-aws-stickler-model-name", "Document")
    match_threshold = class_schema.get("x-aws-stickler-match-threshold", 0.7)

    # Build fields configuration from properties
    fields = {}
    properties = class_schema.get("properties", {})

    for field_name, field_schema in properties.items():
        # Extract Stickler extensions
        comparator = field_schema.get("x-aws-stickler-comparator")
        threshold = field_schema.get("x-aws-stickler-threshold")
        weight = field_schema.get("x-aws-stickler-weight", 1.0)

        if comparator:  # Only include fields with Stickler configuration
            fields[field_name] = {
                "type": "list",  # All fields are arrays in flat format
                "comparator": comparator,
                "threshold": threshold,
                "weight": weight,
                "description": field_schema.get("description", "")
            }

    return {
        "model_name": model_name,
        "match_threshold": match_threshold,
        "fields": fields
    }
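# For reference, extract_stickler_config_from_idp_config() reads class entries
# shaped roughly like the sketch below: a JSON Schema object whose properties
# carry x-aws-stickler-* vendor extensions. The concrete values and field names
# are hypothetical, not taken from sr_FCC_config.json; only the keys are the
# ones the function actually consumes.
#
#   {
#     "classes": [
#       {
#         "x-aws-stickler-model-name": "FCCInvoice",
#         "x-aws-stickler-match-threshold": 0.7,
#         "properties": {
#           "invoice_number": {
#             "description": "Invoice identifier",
#             "x-aws-stickler-comparator": "exact",
#             "x-aws-stickler-threshold": 1.0,
#             "x-aws-stickler-weight": 1.0
#           }
#         }
#       }
#     ]
#   }
#
# Properties without an x-aws-stickler-comparator extension are skipped.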

def normalize_to_list_format(data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize data to list format for all fields."""
    normalized = {}
    for key, value in data.items():
        if value is None:
            normalized[key] = []
        elif isinstance(value, list):
            normalized[key] = value
        elif isinstance(value, str):
            normalized[key] = [value]
        else:
            normalized[key] = [value]
    return normalized
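# Illustration of the normalization above (the values are made up):
#   normalize_to_list_format({"a": None, "b": "x", "c": [1, 2], "d": 3})
#   -> {"a": [], "b": ["x"], "c": [1, 2], "d": [3]}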

def main():
    parser = argparse.ArgumentParser(
        description="Bulk evaluate using IDP configuration directly"
    )
    parser.add_argument("--results-dir", required=True, help="Directory containing inference results")
    parser.add_argument("--csv-path", required=True, help="Path to CSV file with ground truth labels")
    parser.add_argument("--idp-config-path", required=True, help="Path to IDP configuration JSON (e.g., sr_FCC_config.json)")
    parser.add_argument("--doc-id-column", default="doc_id", help="Column name for document IDs")
    parser.add_argument("--labels-column", default="refactored_labels", help="Column name for labels")
    parser.add_argument("--output-dir", default="evaluation_output", help="Output directory")

    args = parser.parse_args()
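    # Input expectations: each row's --labels-column value must be a JSON-encoded
    # dict of attribute values, and each document's inference output is read from
    # <results-dir>/<doc_id>/sections/1/result.json under the "inference_result" key.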
    results_dir = Path(args.results_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("BULK FCC INVOICE EVALUATION (IDP Config)")
    print("=" * 80)

    # Load IDP configuration
    print(f"\n📋 Loading IDP config from {args.idp_config_path}...")
    with open(args.idp_config_path, 'r') as f:
        idp_config = json.load(f)

    # Extract Stickler configuration from IDP config
    print("📋 Extracting Stickler configuration from IDP config...")
    stickler_config = extract_stickler_config_from_idp_config(idp_config)

    print(f"✓ Extracted config for model: {stickler_config['model_name']}")
    print(f"✓ Found {len(stickler_config['fields'])} fields with Stickler configuration")

    # Initialize SticklerEvaluationService
    service_config = {
        "stickler_models": {
            "fcc_invoice": stickler_config
        }
    }
    service = SticklerEvaluationService(config=service_config)
    print("✓ Service initialized")

    # Load ground truth
    print(f"\n📊 Loading ground truth from {args.csv_path}...")
    df = pd.read_csv(args.csv_path)
    df = df[df[args.labels_column].notna()].copy()
    print(f"✓ Loaded {len(df)} documents with ground truth")

    # Load inference results
    print(f"\n📁 Loading inference results from {results_dir}...")
    inference_results = {}
    for doc_dir in results_dir.iterdir():
        if not doc_dir.is_dir():
            continue
        result_path = doc_dir / "sections" / "1" / "result.json"
        if result_path.exists():
            with open(result_path, 'r') as f:
                result_data = json.load(f)
            inference_results[doc_dir.name] = result_data.get("inference_result", {})
    print(f"✓ Loaded {len(inference_results)} inference results")

    # Match and evaluate
    print("\n⚙️ Evaluating documents...")

    # Accumulation state
    overall_metrics = defaultdict(int)
    field_metrics = defaultdict(lambda: defaultdict(int))
    processed = 0
    errors = []

    for _, row in df.iterrows():
        doc_id = row[args.doc_id_column]

        # Find matching result
        result_key = None
        for key in [doc_id, f"{doc_id}.pdf", doc_id.replace('.pdf', '')]:
            if key in inference_results:
                result_key = key
                break

        if not result_key:
            continue

        try:
            # Parse ground truth and get actual results
            expected = json.loads(row[args.labels_column])
            actual = inference_results[result_key]

            # Normalize to list format
            expected = normalize_to_list_format(expected)
            actual = normalize_to_list_format(actual)

            # Create section and evaluate
            section = Section(section_id="1", classification="fcc_invoice", page_ids=["1"])
            result = service.evaluate_section(section, expected, actual)

            # Accumulate metrics from attributes
            for attr in result.attributes:
                field = attr.name
                exp_val = attr.expected
                act_val = attr.actual
                matched = attr.matched
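                # Counting scheme: "fp1" counts predictions for fields that are empty
                # in the ground truth (spurious values), while "fp2" counts fields where
                # both sides have values but they were not judged a match. Both cases
                # also increment the combined "fp" counter.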
                # Determine metric type
                exp_empty = not exp_val or (isinstance(exp_val, list) and len(exp_val) == 0)
                act_empty = not act_val or (isinstance(act_val, list) and len(act_val) == 0)

                if exp_empty and act_empty:
                    overall_metrics["tn"] += 1
                    field_metrics[field]["tn"] += 1
                elif exp_empty and not act_empty:
                    overall_metrics["fp"] += 1
                    overall_metrics["fp1"] += 1
                    field_metrics[field]["fp"] += 1
                    field_metrics[field]["fp1"] += 1
                elif not exp_empty and act_empty:
                    overall_metrics["fn"] += 1
                    field_metrics[field]["fn"] += 1
                elif matched:
                    overall_metrics["tp"] += 1
                    field_metrics[field]["tp"] += 1
                else:
                    overall_metrics["fp"] += 1
                    overall_metrics["fp2"] += 1
                    field_metrics[field]["fp"] += 1
                    field_metrics[field]["fp2"] += 1

            # Save individual result
            result_file = output_dir / f"{doc_id}.json"
            result_data = {
                "doc_id": doc_id,
                "metrics": result.metrics,
                "attributes": [
                    {
                        "name": a.name,
                        "expected": a.expected,
                        "actual": a.actual,
                        "matched": a.matched,
                        "score": float(a.score),
                        "reason": a.reason
                    }
                    for a in result.attributes
                ]
            }
            with open(result_file, 'w') as f:
                json.dump(to_json_serializable(result_data), f, indent=2)

            processed += 1

        except Exception as e:
            errors.append({"doc_id": doc_id, "error": str(e)})
            print(f" ✗ Error evaluating {doc_id}: {e}")

    print(f"✓ Completed evaluation of {processed} documents")

    # Calculate metrics
    def calc_metrics(cm):
        tp, fp, tn, fn = cm["tp"], cm["fp"], cm["tn"], cm["fn"]
        total = tp + fp + tn + fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        accuracy = (tp + tn) / total if total > 0 else 0.0
        return {
            "precision": precision, "recall": recall, "f1_score": f1, "accuracy": accuracy,
            "tp": tp, "fp": fp, "tn": tn, "fn": fn,
            "fp1": cm["fp1"], "fp2": cm["fp2"], "total": total
        }
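    # Counts are pooled across every document/field pair before computing precision,
    # recall, F1 and accuracy, so the overall numbers are micro-averaged rather than
    # averaged per document. By construction, fp = fp1 + fp2.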
    overall = calc_metrics(overall_metrics)
    fields = {field: calc_metrics(cm) for field, cm in field_metrics.items()}

    # Print results
    print("\n" + "=" * 80)
    print("AGGREGATED RESULTS")
    print("=" * 80)
    print(f"\n📊 Summary: {processed} processed, {len(errors)} errors")
    print("\n📈 Overall Metrics:")
    print(f" Precision: {overall['precision']:.4f}")
    print(f" Recall: {overall['recall']:.4f}")
    print(f" F1 Score: {overall['f1_score']:.4f}")
    print(f" Accuracy: {overall['accuracy']:.4f}")
    print("\n Confusion Matrix:")
    print(f" TP: {overall['tp']:6d} | FP: {overall['fp']:6d}")
    print(f" FN: {overall['fn']:6d} | TN: {overall['tn']:6d}")
    print(f" FP1: {overall['fp1']:6d} | FP2: {overall['fp2']:6d}")

    # Top fields
    sorted_fields = sorted(fields.items(), key=lambda x: x[1]["f1_score"], reverse=True)
    print("\n📋 Field-Level Metrics (Top 10):")
    print(f" {'Field':<40} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print(f" {'-'*40} {'-'*10} {'-'*10} {'-'*10}")
    for field, metrics in sorted_fields[:10]:
        print(f" {field:<40} {metrics['precision']:>10.4f} {metrics['recall']:>10.4f} {metrics['f1_score']:>10.4f}")

    # Save aggregated results
    output_file = output_dir / "aggregated_metrics.json"
    with open(output_file, 'w') as f:
        json.dump({
            "summary": {"documents_processed": processed, "errors": len(errors)},
            "overall_metrics": overall,
            "field_metrics": fields,
            "errors": errors,
            "stickler_config_used": stickler_config
        }, f, indent=2)

    print(f"\n💾 Results saved to {output_dir}")
    print("=" * 80)


if __name__ == "__main__":
    main()
