
Commit 73f1a3d

Adding the stickler service.
1 parent e433829 commit 73f1a3d

File tree

5 files changed: +1186 -0 lines changed

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
#!/usr/bin/env python3
"""
Test SticklerEvaluationService with actual FCC invoice data.

This script demonstrates using SticklerEvaluationService to evaluate
real FCC invoice extraction results against ground truth labels.
"""

import json
import os
import pandas as pd
from pathlib import Path
from idp_common.evaluation import SticklerEvaluationService
from idp_common.models import Section


def load_ground_truth_from_csv(csv_path: str, doc_id: str):
    """
    Load ground truth labels from the refactored labels CSV.

    Args:
        csv_path: Path to the CSV file with refactored_labels column
        doc_id: Document ID to look up

    Returns:
        Dictionary of ground truth labels
    """
    df = pd.read_csv(csv_path)

    # Find the row for this document
    row = df[df['doc_id'] == doc_id]

    if row.empty:
        print(f"Warning: No ground truth found for doc_id: {doc_id}")
        return None

    # Parse the refactored_labels JSON
    labels_json = row['refactored_labels'].values[0]

    if pd.isna(labels_json):
        print(f"Warning: refactored_labels is empty for doc_id: {doc_id}")
        return None

    try:
        labels = json.loads(labels_json)
        return labels
    except json.JSONDecodeError as e:
        print(f"Error parsing refactored_labels JSON: {e}")
        return None


def create_fcc_stickler_config():
    """
    Create a Stickler configuration for FCC invoices.

    Returns:
        Configuration dictionary for SticklerEvaluationService
    """
    config = {
        "stickler_models": {
            "fcc-invoice": {
                "model_name": "FCCInvoice",
                "match_threshold": 0.7,
                "fields": {
                    "agency": {
                        "type": "str",
                        "comparator": "FuzzyComparator",
                        "threshold": 0.8,
                        "weight": 2.0,
                    },
                    "advertiser": {
                        "type": "str",
                        "comparator": "FuzzyComparator",
                        "threshold": 0.8,
                        "weight": 2.0,
                    },
                    "gross_total": {
                        "type": "str",  # Stored as string with commas
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 3.0,
                    },
                    "net_amount_due": {
                        "type": "str",  # Stored as string with commas
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 3.0,
                    },
                    "line_item__description": {
                        "type": "list",
                        "comparator": "LevenshteinComparator",
                        "threshold": 0.7,
                        "weight": 1.5,
                    },
                    "line_item__days": {
                        "type": "list",
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 1.0,
                    },
                    "line_item__rate": {
                        "type": "list",
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 2.0,
                    },
                    "line_item__start_date": {
                        "type": "list",
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 2.0,
                    },
                    "line_item__end_date": {
                        "type": "list",
                        "comparator": "ExactComparator",
                        "threshold": 1.0,
                        "weight": 2.0,
                    },
                },
            }
        }
    }

    return config


def main():
    """Run the FCC invoice evaluation test."""

    print("=" * 80)
    print("SticklerEvaluationService - FCC Invoice Data Test")
    print("=" * 80)

    # Paths
    csv_path = "sr_refactor_labels_5_5_25.csv"
    data_dir = "tmp_data/cli-batch-20251017-154358"

    # Check if paths exist
    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found: {csv_path}")
        return

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found: {data_dir}")
        return

    # Create Stickler configuration
    print("\n1. Creating Stickler configuration for FCC invoices...")
    config = create_fcc_stickler_config()
    print(" ✓ Configuration created")

    # Initialize service
    print("\n2. Initializing SticklerEvaluationService...")
    service = SticklerEvaluationService(config=config)
    print(f" ✓ Service initialized with models: {list(service.stickler_models.keys())}")

    # Find a sample document to test
    doc_dirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

    if not doc_dirs:
        print("Error: No document directories found")
        return

    # Use the first document
    sample_doc = doc_dirs[0]
    doc_path = os.path.join(data_dir, sample_doc)

    print(f"\n3. Testing with document: {sample_doc}")

    # Load the extraction result
    result_path = os.path.join(doc_path, "sections/1/result.json")

    if not os.path.exists(result_path):
        print(f"Error: Result file not found: {result_path}")
        return

    with open(result_path, 'r') as f:
        result_data = json.load(f)

    actual_results = result_data.get('inference_result', {})
    doc_class = result_data.get('document_class', {}).get('type', 'unknown')

    print(f" Document class: {doc_class}")
    print(f" Extracted fields: {list(actual_results.keys())}")

    # Load ground truth from CSV
    # Extract doc_id from filename (remove .pdf extension)
    doc_id_from_filename = sample_doc.replace('.pdf', '')

    print(f"\n4. Loading ground truth for doc_id: {doc_id_from_filename}")
    ground_truth = load_ground_truth_from_csv(csv_path, doc_id_from_filename)

    if ground_truth is None:
        print(" Warning: No ground truth available, using actual results as expected")
        print(" This will show perfect matches (for demonstration purposes)")
        expected_results = actual_results
    else:
        expected_results = ground_truth
        print(f" ✓ Ground truth loaded with {len(expected_results)} fields")

    # Create a section
    section = Section(
        section_id="section1",
        classification="fcc-invoice",
        page_ids=["page1"]
    )

    # Evaluate
    print("\n5. Evaluating extraction results...")
    try:
        result = service.evaluate_section(
            section=section,
            expected_results=expected_results,
            actual_results=actual_results
        )

        print(" ✓ Evaluation completed")

        # Display results
        print("\n6. Evaluation Results")
        print("-" * 80)
        print(f"Section ID: {result.section_id}")
        print(f"Document Class: {result.document_class}")

        if result.metrics:
            print("\nMetrics:")
            for metric_name, metric_value in result.metrics.items():
                print(f" {metric_name:25} {metric_value:.4f}")

        if result.attributes:
            print(f"\nAttribute Results ({len(result.attributes)} attributes):")
            print(f"{'Attribute':<30} {'Match':<8} {'Score':<8}")
            print("-" * 50)

            matched_count = 0
            for attr in result.attributes[:20]:  # Show first 20
                match_symbol = "✓" if attr.matched else "✗"
                if attr.matched:
                    matched_count += 1
                print(f"{attr.name:<30} {match_symbol:<8} {attr.score:<8.3f}")

            if len(result.attributes) > 20:
                print(f"... and {len(result.attributes) - 20} more attributes")

            print(f"\nSummary: {matched_count}/{len(result.attributes)} attributes matched")
        else:
            print("\nNo attributes evaluated (model may not be configured for this class)")

    except Exception as e:
        print(f" ✗ Error during evaluation: {str(e)}")
        import traceback
        traceback.print_exc()

    print("\n" + "=" * 80)
    print("Test completed!")
    print("=" * 80)


if __name__ == "__main__":
    main()
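For reference, a minimal sketch of the same evaluation flow with inline data instead of the CSV and batch-output files. It reuses only the calls shown in the diff above (SticklerEvaluationService, Section, evaluate_section, and the stickler_models config shape); the two field values are made-up illustrations, not real FCC invoice data:

# Minimal sketch, assuming the same config shape as create_fcc_stickler_config();
# the sample agency/gross_total values are hypothetical.
from idp_common.evaluation import SticklerEvaluationService
from idp_common.models import Section

config = {
    "stickler_models": {
        "fcc-invoice": {
            "model_name": "FCCInvoice",
            "match_threshold": 0.7,
            "fields": {
                "agency": {"type": "str", "comparator": "FuzzyComparator",
                           "threshold": 0.8, "weight": 2.0},
                "gross_total": {"type": "str", "comparator": "ExactComparator",
                                "threshold": 1.0, "weight": 3.0},
            },
        }
    }
}

service = SticklerEvaluationService(config=config)
section = Section(section_id="section1", classification="fcc-invoice", page_ids=["page1"])

# Hand-built ground truth vs. extraction output (illustrative values only)
expected = {"agency": "Acme Media Buyers", "gross_total": "12,500.00"}
actual = {"agency": "ACME Media Buyers", "gross_total": "12,500.00"}

result = service.evaluate_section(
    section=section,
    expected_results=expected,
    actual_results=actual,
)

# Per-attribute match flags and scores, as in step 6 of the test script
for attr in result.attributes:
    print(attr.name, attr.matched, attr.score)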
