Skip to content

Commit cf4a848

Browse files
authored
Merge pull request #39 from SAFEHR-data/tomr/ocr-analysis
Adds OCR evaluation script / Fixes WER calculation
2 parents 4845baa + b4c9813 commit cf4a848

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

.github/workflows/tests.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ jobs:
3535
- name: Cache docker/setup-buildx
3636
uses: docker/setup-buildx-action@v3
3737

38+
- name: Increase disk space available for building images
39+
run: |
40+
sudo rm -rf \
41+
/usr/share/dotnet \
42+
/usr/local/lib/android \
43+
/usr/local/.ghcup \
44+
/opt/ghc \
45+
"$AGENT_TOOLSDIRECTORY" \
46+
/usr/local/share/powershell \
47+
/usr/share/swift \
48+
/usr/lib/jvm || true
49+
3850
- name: Disk Usage - initial size
3951
run: |
4052
echo "Disk usage summary:"

src/pyonb/analysis/eval_ocr.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from pyonb.analysis.metrics import cer, ned, wer
99

1010

11-
def read_file(file_path: Path) -> str | dict:
11+
def read_file(file_path: Path, file_encoding: str | None = None) -> str | dict:
1212
"""Read .txt or .json file."""
13-
with Path.open(file_path, "r") as f:
13+
with Path.open(file_path, "r", encoding=file_encoding) as f:
1414
file_type = file_path.suffix.lower()
1515

1616
if file_type == ".json":
@@ -31,14 +31,10 @@ def evaluate_metrics(gt_text: str, ocr_text: str) -> dict:
3131
return {"cer": cer_result, "wer": wer_result, "ned": ned_result}
3232

3333

34-
if __name__ == "__main__":
35-
parser = argparse.ArgumentParser(description="Run and evaluate OCR performance metrics.")
36-
parser.add_argument("-gt", "--ground_truth_file", type=str, required=True, help="[.txt] Path to ground truth file.")
37-
parser.add_argument("-ocr", "--ocr_file", type=str, required=True, help="[.json/.txt] Path to OCR processed file.")
38-
args = parser.parse_args()
39-
40-
gt_file_output = read_file(Path(args.ground_truth_file))
41-
ocr_file_output = read_file(Path(args.ocr_file))
34+
def run(gt_path: Path, ocr_path: Path) -> dict:
35+
"""Run OCR evaluation given ground truth and OCR file paths."""
36+
gt_file_output = read_file(gt_path)
37+
ocr_file_output = read_file(ocr_path)
4238

4339
if isinstance(ocr_file_output, str):
4440
result = evaluate_metrics(str(gt_file_output), str(ocr_file_output))
@@ -49,4 +45,14 @@ def evaluate_metrics(gt_text: str, ocr_text: str) -> dict:
4945
msg = "OCR file is not .txt or .json."
5046
raise TypeError(msg)
5147

52-
print(f"OCR Evaluation results:\n{result}") # noqa: T201
48+
return result
49+
50+
51+
if __name__ == "__main__":
    # CLI entry point: collect the two file paths, delegate to run(), report.
    arg_parser = argparse.ArgumentParser(description="Run and evaluate OCR performance metrics.")
    arg_parser.add_argument("-gt", "--ground_truth_file", type=str, required=True, help="[.txt] Path to ground truth file.")
    arg_parser.add_argument("-ocr", "--ocr_file", type=str, required=True, help="[.json/.txt] Path to OCR processed file.")
    cli_args = arg_parser.parse_args()

    results = run(Path(cli_args.ground_truth_file), Path(cli_args.ocr_file))
    print(f"OCR Evaluation results:\n{results}")  # noqa: T201

src/pyonb/analysis/metrics.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def cer(gt: str, pred: str) -> float:
88
Character Error Rate (CER): edit distance / length of ground truth.
99
1010
CER = 0 - Perfect character match
11-
CER = 1 - Completely different
11+
CER > 0 - ratio of character edits needed; values > 1.0 indicate more edits than original characters
1212
"""
1313
if not gt:
1414
return float("inf") if pred else 0.0
def wer(gt: str, pred: str) -> float:
    """Compute the Word Error Rate (WER) between two strings.

    Word Error Rate (WER): word-level edit distance / number of ground-truth words.

    WER = 0 - Perfect word match
    WER > 0 - ratio of word edits needed; values > 1.0 indicate more edits than original words
    """
    gt_words = gt.split()
    pred_words = pred.split()

    # Wagner-Fischer dynamic-programming matrix for word-level Levenshtein
    # distance. dp[i][j] = cost of transforming the first i ground-truth words
    # into the first j predicted words.
    dp = [[0] * (len(pred_words) + 1) for _ in range(len(gt_words) + 1)]

    # Base cases: transforming to/from an empty sequence costs its length.
    for i in range(len(gt_words) + 1):
        dp[i][0] = i
    for j in range(len(pred_words) + 1):
        dp[0][j] = j

    for i in range(1, len(gt_words) + 1):
        for j in range(1, len(pred_words) + 1):
            cost = 0 if gt_words[i - 1] == pred_words[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + cost,  # substitution
            )

    # max(1, ...) guards against division by zero for an empty ground truth.
    return round(dp[len(gt_words)][len(pred_words)] / max(1, len(gt_words)), 3)
2846

2947

3048
def emr(gt_list: list[str], pred_list: list[str]) -> float:
@@ -41,7 +59,7 @@ def ned(gt: str, pred: str) -> float:
4159
Normalized Edit Distance: edit distance / max length.
4260
4361
NED = 0 - Perfect match, strings identical
44-
NED = 1 - Maximum dissimilarity
62+
NED = 1 - maximum dissimilarity (every character position must change)
4563
"""
4664
max_len = max(len(gt), len(pred))
4765
if max_len == 0:

0 commit comments

Comments
 (0)