Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions backend/app/matcher/voter_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,32 @@
import os
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from enum import StrEnum

import numpy as np
import pandas as pd
from pandas import DataFrame
from rapidfuzz import fuzz
from tqdm import tqdm

from ..voters.voter_regions import DCVoterRecordSpec, VoterRecordProcessor


class MatchedColumn(StrEnum):
OCR_NAME = "OCR Name"
OCR_ADDRESS = "OCR Address"
MATCHED_NAME = "Matched Name"
MATCHED_ADDRESS = "Matched Address"
DATE = "Date"
MATCHED_SCORE = "Match Score"
VALID = "Valid"
PAGE_NUMBER = "Page Number"
ROW_NUMBER = "Row Number"
FILE_NAME = "Filename"


_voter_record_processor: VoterRecordProcessor = VoterRecordProcessor()

# load config
with open("config.json", "r") as f:
config = json.load(f)
Expand Down Expand Up @@ -55,20 +74,19 @@ def _create_select_voter_records(voter_records: pd.DataFrame) -> pd.DataFrame:
Returns:
pd.DataFrame: DataFrame with 'Full Name' and 'Full Address' columns
"""

# TODO make it configurable
_voter_record_processor.set_voter_region(DCVoterRecordSpec())

# Create full name by combining first and last names
name_components = ["First_Name", "Last_Name"]
name_components = _voter_record_processor.get_name_components()
voter_records[name_components] = voter_records[name_components].fillna("")
voter_records["Full Name"] = (
voter_records[name_components].astype(str).agg(" ".join, axis=1)
)

# Create full address by combining address components
address_components = [
"Street_Number",
"Street_Name",
"Street_Type",
"Street_Dir_Suffix",
]
address_components = _voter_record_processor.get_address_components()
voter_records[address_components] = voter_records[address_components].fillna("")
voter_records["Full Address"] = (
voter_records[address_components].astype(str).agg(" ".join, axis=1)
Expand Down Expand Up @@ -172,7 +190,7 @@ def create_ocr_matched_df(

Args:
ocr_df (pd.DataFrame): The DataFrame containing OCR results.
select_voter_records (pd.DataFrame): The DataFrame containing voter records.
voter_records (pd.DataFrame): The DataFrame containing voter records.
threshold (float): The threshold for matching.
st_bar (st.progress): The progress bar to display.

Expand Down Expand Up @@ -202,7 +220,9 @@ def create_ocr_matched_df(
batch_results = list(
executor.map(
lambda row: _get_matched_name_address(
row["OCR Name"], row["OCR Address"], select_voter_records
row[MatchedColumn.OCR_NAME],
row[MatchedColumn.OCR_ADDRESS],
select_voter_records,
),
[row for _, row in batch.iterrows()],
)
Expand All @@ -227,26 +247,16 @@ def create_ocr_matched_df(
text=f"Processing batch {batch_start} out of {len(ocr_df) // batch_size + 1} batches",
)

matched_columns: list[str] = [
MatchedColumn.MATCHED_NAME,
MatchedColumn.MATCHED_ADDRESS,
MatchedColumn.MATCHED_SCORE,
]

logger.info("Creating final DataFrame")
match_df = pd.DataFrame(
results, columns=["Matched Name", "Matched Address", "Match Score"]
)
match_df = pd.DataFrame(results, columns=matched_columns)
result_df = pd.concat([ocr_df, match_df], axis=1)
result_df["Valid"] = result_df["Match Score"] >= threshold

# Reorder columns
column_order = [
"OCR Name",
"OCR Address",
"Matched Name",
"Matched Address",
"Date",
"Match Score",
"Valid",
"Page Number",
"Row Number",
"Filename",
]
result_df["Valid"] = result_df[MatchedColumn.MATCHED_SCORE] >= threshold

# Log final statistics
total_valid = result_df["Valid"].sum()
Expand All @@ -255,4 +265,4 @@ def create_ocr_matched_df(
f"Valid matches: {total_valid} ({total_valid / len(result_df) * 100:.1f}%)"
)

return result_df[column_order]
return result_df[list(MatchedColumn)]
2 changes: 1 addition & 1 deletion backend/app/ocr/ocr_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
def _collect_ocr_data(
filedir: str,
filename: str,
max_page_num: int = None,
max_page_num: int | None = None,
batch_size: int = 10,
st_bar=None,
) -> list[dict]:
Expand Down
3 changes: 3 additions & 0 deletions backend/app/voters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .voter_regions import DCVoterRecordSpec, VoterRecordProcessor, VoterRecordSpec

__all__ = ["DCVoterRecordSpec", "VoterRecordSpec", "VoterRecordProcessor"]
54 changes: 54 additions & 0 deletions backend/app/voters/voter_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
from typing import Protocol


class VoterRecordSpec(Protocol):

@abstractmethod
def name_components(self) -> list[str]:
pass

@abstractmethod
def address_components(self) -> list[str]:
pass


class UndefinedVoterRecordSpec:

def name_components(self) -> list[str]:
return []

def address_components(self) -> list[str]:
return []


class DCVoterRecordSpec:
def name_components(self) -> list[str]:
return ["First_name", "Last_name"]

def address_components(self) -> list[str]:
return [
"Street_Number",
"Street_Name",
"Street_Type",
"Street_Dir_Suffix",
]


class VoterRecordProcessor:

current_voter_region: VoterRecordSpec

def __init__(self, voter_region_spec: VoterRecordSpec | None = None):
self.current_voter_region = (
voter_region_spec if voter_region_spec else UndefinedVoterRecordSpec()
)

def set_voter_region(self, voter_region: VoterRecordSpec):
self.current_voter_region = voter_region

def get_name_components(self) -> list[str]:
return self.current_voter_region.name_components()

def get_address_components(self) -> list[str]:
return self.current_voter_region.address_components()
3 changes: 3 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ license = "MIT"
dependencies = [
"dotenv>=0.9.9",
"fastapi[standard]>=0.115.12",
"fitz>=0.0.1.dev2",
"httpx>=0.28.1",
"ipywidgets>=8.1.5",
"langchain-core>=0.3.51",
Expand Down Expand Up @@ -45,7 +46,9 @@ dependencies = [
dev = [
"black>=25.1.0",
"flake8>=7.2.0",
"pyright>=1.1.406",
"pytest>=8.3.5",
"pytest-asyncio>=1.2.0",
"pytest-cov>=6.1.1",
"ruff>=0.11.4",
]
Expand Down
Loading