#!/usr/bin/env python3
"""
Bulk evaluation script that uses IDP configuration directly.

This script reads the Stickler evaluation configuration from the IDP config file
(sr_FCC_config.json) instead of requiring a separate stickler_config.json.
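
Example usage (script name and paths here are illustrative, not part of the original):

    python bulk_evaluate_with_idp_config.py \
        --results-dir ./inference_results \
        --csv-path ./ground_truth_labels.csv \
        --idp-config-path ./sr_FCC_config.json \
        --output-dir ./evaluation_output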
| 7 | +""" |
| 8 | + |
| 9 | +import argparse |
| 10 | +import json |
| 11 | +import sys |
| 12 | +from pathlib import Path |
| 13 | +from typing import Dict, Any |
| 14 | +import pandas as pd |
| 15 | +import numpy as np |
| 16 | +from collections import defaultdict |
| 17 | + |
| 18 | +# Add lib path for idp_common imports |
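# (parents[3] is assumed to resolve to the repo root, which contains lib/idp_common_pkg)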
sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "lib" / "idp_common_pkg"))

from idp_common.evaluation.stickler_service import SticklerEvaluationService
from idp_common.models import Section


def to_json_serializable(obj):
    """Convert numpy types to Python native types."""
    if isinstance(obj, (np.bool_, np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {k: to_json_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [to_json_serializable(item) for item in obj]
    return obj


def extract_stickler_config_from_idp_config(idp_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract the Stickler configuration from the IDP config's JSON Schema class definition.

    Args:
        idp_config: Full IDP configuration

    Returns:
        Stickler configuration in the format expected by SticklerEvaluationService
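
        Illustrative return shape (field entries and values are examples only):
            {
                "model_name": "Document",
                "match_threshold": 0.7,
                "fields": {
                    "<field_name>": {"type": "list", "comparator": "...",
                                     "threshold": 0.8, "weight": 1.0,
                                     "description": "..."}
                }
            }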
    """
    # Get the first class definition (assuming single document type)
    if "classes" not in idp_config or not idp_config["classes"]:
        raise ValueError("No classes found in IDP configuration")

    class_schema = idp_config["classes"][0]

    # Extract model name and threshold
    model_name = class_schema.get("x-aws-stickler-model-name", "Document")
    match_threshold = class_schema.get("x-aws-stickler-match-threshold", 0.7)

    # Build fields configuration from properties
    fields = {}
    properties = class_schema.get("properties", {})

    for field_name, field_schema in properties.items():
        # Extract Stickler extensions
        comparator = field_schema.get("x-aws-stickler-comparator")
        threshold = field_schema.get("x-aws-stickler-threshold")
        weight = field_schema.get("x-aws-stickler-weight", 1.0)

        if comparator:  # Only include fields with Stickler configuration
            fields[field_name] = {
                "type": "list",  # All fields are arrays in flat format
                "comparator": comparator,
                "threshold": threshold,
                "weight": weight,
                "description": field_schema.get("description", "")
            }

    return {
        "model_name": model_name,
        "match_threshold": match_threshold,
        "fields": fields
    }


def normalize_to_list_format(data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize every field to a list: None becomes [], scalars are wrapped in a one-element list."""
    normalized = {}
    for key, value in data.items():
        if value is None:
            normalized[key] = []
        elif isinstance(value, list):
            normalized[key] = value
        else:
            # Strings and other scalars are wrapped in a single-element list
            normalized[key] = [value]
    return normalized


def main():
    parser = argparse.ArgumentParser(
        description="Bulk evaluate using IDP configuration directly"
    )
    parser.add_argument("--results-dir", required=True, help="Directory containing inference results")
    parser.add_argument("--csv-path", required=True, help="Path to CSV file with ground truth labels")
    parser.add_argument("--idp-config-path", required=True, help="Path to IDP configuration JSON (e.g., sr_FCC_config.json)")
    parser.add_argument("--doc-id-column", default="doc_id", help="Column name for document IDs")
    parser.add_argument("--labels-column", default="refactored_labels", help="Column name for labels")
    parser.add_argument("--output-dir", default="evaluation_output", help="Output directory")

    args = parser.parse_args()

    results_dir = Path(args.results_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("BULK FCC INVOICE EVALUATION (IDP Config)")
    print("=" * 80)

    # Load IDP configuration
    print(f"\n📋 Loading IDP config from {args.idp_config_path}...")
    with open(args.idp_config_path, 'r') as f:
        idp_config = json.load(f)

    # Extract Stickler configuration from IDP config
    print("📋 Extracting Stickler configuration from IDP config...")
    stickler_config = extract_stickler_config_from_idp_config(idp_config)

    print(f"✓ Extracted config for model: {stickler_config['model_name']}")
    print(f"✓ Found {len(stickler_config['fields'])} fields with Stickler configuration")

    # Initialize SticklerEvaluationService
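    # ("fcc_invoice" is assumed to be the model key the service resolves from the
    # Section classification used during evaluation below)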
    service_config = {
        "stickler_models": {
            "fcc_invoice": stickler_config
        }
    }
    service = SticklerEvaluationService(config=service_config)
    print("✓ Service initialized")

    # Load ground truth
    print(f"\n📊 Loading ground truth from {args.csv_path}...")
    df = pd.read_csv(args.csv_path)
    df = df[df[args.labels_column].notna()].copy()
    print(f"✓ Loaded {len(df)} documents with ground truth")

    # Load inference results
    print(f"\n📁 Loading inference results from {results_dir}...")
    inference_results = {}
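    # Expected layout: <results_dir>/<doc_id>/sections/1/result.json, where the
    # extracted values live under the "inference_result" key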
    for doc_dir in results_dir.iterdir():
        if not doc_dir.is_dir():
            continue
        result_path = doc_dir / "sections" / "1" / "result.json"
        if result_path.exists():
            with open(result_path, 'r') as f:
                result_data = json.load(f)
            inference_results[doc_dir.name] = result_data.get("inference_result", {})
    print(f"✓ Loaded {len(inference_results)} inference results")

    # Match and evaluate
    print("\n⚙️ Evaluating documents...")

    # Accumulation state
    overall_metrics = defaultdict(int)
    field_metrics = defaultdict(lambda: defaultdict(int))
    processed = 0
    errors = []

    for _, row in df.iterrows():
        doc_id = str(row[args.doc_id_column])

        # Find matching result
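        # (result directories may or may not carry a .pdf suffix, so try both forms)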
        result_key = None
        for key in [doc_id, f"{doc_id}.pdf", doc_id.replace('.pdf', '')]:
            if key in inference_results:
                result_key = key
                break

        if not result_key:
            continue

        try:
            # Parse ground truth and get actual results
            expected = json.loads(row[args.labels_column])
            actual = inference_results[result_key]

            # Normalize to list format
            expected = normalize_to_list_format(expected)
            actual = normalize_to_list_format(actual)

            # Create section and evaluate
            section = Section(section_id="1", classification="fcc_invoice", page_ids=["1"])
            result = service.evaluate_section(section, expected, actual)

            # Accumulate metrics from attributes
            for attr in result.attributes:
                field = attr.name
                exp_val = attr.expected
                act_val = attr.actual
                matched = attr.matched

                # Determine metric type
                exp_empty = not exp_val or (isinstance(exp_val, list) and len(exp_val) == 0)
                act_empty = not act_val or (isinstance(act_val, list) and len(act_val) == 0)

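                # Outcome classification used below:
                #   TN  - expected and actual both empty
                #   FP1 - expected empty but a value was extracted (spurious extraction)
                #   FN  - expected present but nothing was extracted
                #   TP  - both present and the comparator matched them
                #   FP2 - both present but the values did not match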
                if exp_empty and act_empty:
                    overall_metrics["tn"] += 1
                    field_metrics[field]["tn"] += 1
                elif exp_empty and not act_empty:
                    overall_metrics["fp"] += 1
                    overall_metrics["fp1"] += 1
                    field_metrics[field]["fp"] += 1
                    field_metrics[field]["fp1"] += 1
                elif not exp_empty and act_empty:
                    overall_metrics["fn"] += 1
                    field_metrics[field]["fn"] += 1
                elif matched:
                    overall_metrics["tp"] += 1
                    field_metrics[field]["tp"] += 1
                else:
                    overall_metrics["fp"] += 1
                    overall_metrics["fp2"] += 1
                    field_metrics[field]["fp"] += 1
                    field_metrics[field]["fp2"] += 1

            # Save individual result
            result_file = output_dir / f"{doc_id}.json"
            result_data = {
                "doc_id": doc_id,
                "metrics": result.metrics,
                "attributes": [
                    {
                        "name": a.name,
                        "expected": a.expected,
                        "actual": a.actual,
                        "matched": a.matched,
                        "score": float(a.score),
                        "reason": a.reason
                    }
                    for a in result.attributes
                ]
            }
            with open(result_file, 'w') as f:
                json.dump(to_json_serializable(result_data), f, indent=2)

            processed += 1

        except Exception as e:
            errors.append({"doc_id": doc_id, "error": str(e)})
            print(f"  ✗ Error evaluating {doc_id}: {e}")

    print(f"✓ Completed evaluation of {processed} documents")

    # Calculate metrics
    def calc_metrics(cm):
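        """Turn accumulated confusion-matrix counts into precision/recall/F1/accuracy (fp = fp1 + fp2)."""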
        tp, fp, tn, fn = cm["tp"], cm["fp"], cm["tn"], cm["fn"]
        total = tp + fp + tn + fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        accuracy = (tp + tn) / total if total > 0 else 0.0
        return {
            "precision": precision, "recall": recall, "f1_score": f1, "accuracy": accuracy,
            "tp": tp, "fp": fp, "tn": tn, "fn": fn,
            "fp1": cm["fp1"], "fp2": cm["fp2"], "total": total
        }

    overall = calc_metrics(overall_metrics)
    fields = {field: calc_metrics(cm) for field, cm in field_metrics.items()}

    # Print results
    print("\n" + "=" * 80)
    print("AGGREGATED RESULTS")
    print("=" * 80)
    print(f"\n📊 Summary: {processed} processed, {len(errors)} errors")
    print("\n📈 Overall Metrics:")
    print(f"   Precision: {overall['precision']:.4f}")
    print(f"   Recall:    {overall['recall']:.4f}")
    print(f"   F1 Score:  {overall['f1_score']:.4f}")
    print(f"   Accuracy:  {overall['accuracy']:.4f}")
    print("\n   Confusion Matrix:")
    print(f"     TP:  {overall['tp']:6d} | FP:  {overall['fp']:6d}")
    print(f"     FN:  {overall['fn']:6d} | TN:  {overall['tn']:6d}")
    print(f"     FP1: {overall['fp1']:6d} | FP2: {overall['fp2']:6d}")

    # Top fields
    sorted_fields = sorted(fields.items(), key=lambda x: x[1]["f1_score"], reverse=True)
    print("\n📋 Field-Level Metrics (Top 10):")
    print(f"   {'Field':<40} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print(f"   {'-'*40} {'-'*10} {'-'*10} {'-'*10}")
    for field, metrics in sorted_fields[:10]:
        print(f"   {field:<40} {metrics['precision']:>10.4f} {metrics['recall']:>10.4f} {metrics['f1_score']:>10.4f}")

    # Save aggregated results
    output_file = output_dir / "aggregated_metrics.json"
    with open(output_file, 'w') as f:
        json.dump({
            "summary": {"documents_processed": processed, "errors": len(errors)},
            "overall_metrics": overall,
            "field_metrics": fields,
            "errors": errors,
            "stickler_config_used": stickler_config
        }, f, indent=2)

    print(f"\n💾 Results saved to {output_dir}")
    print("=" * 80)


if __name__ == "__main__":
    main()