def normalize_case_name(name: str) -> str:
    """Convert a case name to a canonical form suitable for equality checks.

    The result is lowercase, uses ``v.`` as the only party separator, keeps
    only letters, digits, periods, ampersands, apostrophes and single spaces,
    and has no leading or trailing whitespace.

    Args:
        name: The case name as written, e.g. ``"Fox & al. vs. Cosby"``.

    Returns:
        The normalized case name, e.g. ``"fox & al. v. cosby"``.
    """
    # 1. Lowercase everything.
    name = name.lower()

    # 2. Normalize 'v', 'v.', 'vs', 'vs.' and 'versus' to 'v.'. The optional
    #    period is consumed *after* the word boundary; putting the boundary
    #    after the period would never match "vs. " because '.' and ' ' are
    #    both non-word characters.
    name = re.sub(r"\b(?:versus|vs|v)\b\.?", "v.", name)

    # 3. Drop everything except letters, digits, periods, ampersands,
    #    apostrophes and whitespace. Whitespace is kept here so tabs and
    #    newlines still separate words until step 4 collapses them.
    name = re.sub(r"[^a-z0-9.&'\s]+", "", name)

    # 4. Collapse every run of whitespace into a single space.
    name = re.sub(r"\s+", " ", name)

    return name.strip()
def extract_case_names(ctx: Context) -> list[str]:
    """Collect the case names cited in the most recent LLM output.

    The text of ``ctx.last_output()`` is parsed with eyecite's
    ``get_citations``. A name is produced as ``"<plaintiff> v. <defendant>"``
    only for citations whose metadata carries both parties.

    Args:
        ctx: Context whose last output holds the text to scan.

    Returns:
        The unique case names, in order of first appearance. Empty when there
        is no output text or no citation names both parties.
    """
    text = getattr(ctx.last_output(), "value", None)
    if not text:
        return []

    # Dict keys drop duplicates while keeping first-seen order, so the name
    # reported as missing by citation_exists is stable between runs.
    case_names: dict[str, None] = {}
    for citation in get_citations(text):
        # eyecite exposes metadata as attributes on a Metadata object rather
        # than as a dict, and only case citations define plaintiff/defendant,
        # so read both defensively.
        metadata = citation.metadata
        plaintiff = getattr(metadata, "plaintiff", None)
        defendant = getattr(metadata, "defendant", None)
        if plaintiff and defendant:
            case_names[f"{plaintiff} v. {defendant}"] = None

    return list(case_names)


def citation_exists(ctx: Context, case_metadata: list[dict]) -> ValidationResult:
    """Check that every case cited in the last LLM output is in the database.

    Cited names and database names are both passed through
    ``normalize_case_name`` before comparison. A cited case matches when it
    equals either the ``name`` or the ``name_abbreviation`` of any entry.

    Args:
        ctx: Context whose last output contains the citations to check.
        case_metadata: Case metadata records; the ``name`` and
            ``name_abbreviation`` keys are used when present.

    Returns:
        A failing ValidationResult when there is no context, no last output,
        no extractable case name, or a cited case that is absent from the
        database; a passing ValidationResult otherwise.
    """
    if ctx is None:
        return ValidationResult(False, reason="No context provided in output")

    if ctx.last_output() is None:
        return ValidationResult(False, reason="No last output found in context")

    cited_names = extract_case_names(ctx)
    if not cited_names:
        return ValidationResult(False, reason="No case names provided in output")

    # Full names and abbreviations are pooled: a citation is accepted when it
    # matches either form.
    known_names = {
        normalize_case_name(case[key])
        for case in case_metadata
        for key in ("name", "name_abbreviation")
        if key in case
    }

    for cited_name in cited_names:
        normalized_case_name = normalize_case_name(cited_name)
        if normalized_case_name not in known_names:
            # TODO: consider reporting the name as cited instead of the
            # normalized form.
            return ValidationResult(False, reason=f"'{normalized_case_name}' not found in database")

    return ValidationResult(True, reason="All case names found in database")


class CaseNameExistsInDatabase(Requirement):
    """Requirement that every cited case name exists in the case metadata database."""

    def __init__(self, case_metadata: list[dict]):
        """Bind the requirement to a collection of case metadata records.

        Args:
            case_metadata: Case metadata records, passed unchanged to
                ``citation_exists`` on every validation.
        """
        self._case_metadata = case_metadata
        super().__init__(
            description="The case name should exist in the provided case metadata database.",
            validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata),
        )
import pytest

from mellea_contribs.reqlib.citation_exists import citation_exists, normalize_case_name

# In-memory stand-in for the case metadata JSON files. It holds exactly the
# cases the parametrized tests below expect to be found.
CASE_METADATA = [
    {"name": "Groves v. Graves", "name_abbreviation": "Groves v. Graves"},
    {
        "name": "William Payne, Executor of John Payne v. William Dudley Executor of Fleet",
        "name_abbreviation": "Payne v. Dudley",
    },
    {"name": "Fox & al. v. Cosby", "name_abbreviation": "Fox v. Cosby"},
]


class MockContext:
    """Minimal context stand-in whose last output is a fixed piece of text."""

    def __init__(self, case_name):
        self._case_name = case_name

    def last_output(self):
        # Only the ``value`` attribute of the output is read by the code under test.
        return type("MockOutput", (), {"value": self._case_name})()


# region: normalize_case_name tests
@pytest.mark.parametrize("raw_name,expected", [
    ("BOB VS SHMEEGUS", "bob v. shmeegus"),
    (
        "William Payne, Executor of John Payne v. William Dudley Executor of Fleet",
        "william payne executor of john payne v. william dudley executor of fleet",
    ),
    ("Ozwald v. Dickinson's Ex'rs", "ozwald v. dickinson's ex'rs"),
    ("Fox & al. v. Cosby", "fox & al. v. cosby"),
    ("Groves v. Graves", "groves v. graves"),
    ("Ozwald, Deniston, & Co. v. Dickinson's Ex'rs", "ozwald deniston & co. v. dickinson's ex'rs"),
    ("Bobby- versus shmeegy", "bobby v. shmeegy"),
])
def test_normalize_case_name(raw_name, expected):
    """Each raw name normalizes to the expected canonical form."""
    assert normalize_case_name(raw_name) == expected
# endregion


# region: citation_exists tests
# NOTE(review): these inputs are bare case names with no reporter citation
# (volume / reporter / page). eyecite may not detect a citation in such text,
# in which case the True rows would fail with "No case names provided in
# output" -- confirm, and extend the inputs to full citations if needed.
@pytest.mark.parametrize("case_name,expected", [
    ("Bob v. Shmeegus", False),
    ("Gimli versus Legolas", False),
    ("Groves v. Graves", True),
    ("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", True),
    ("Payne v. Dudley", True),
    ("Fox & al. v. Cosby", True),
    ("Fox v. Cosby", True),
])
def test_citation_exists(case_name, expected):
    """A cited case validates only when it is present in CASE_METADATA."""
    ctx = MockContext(case_name)

    # citation_exists takes the metadata records themselves, not a folder path.
    result = citation_exists(ctx, CASE_METADATA)
    assert result.as_bool() == expected, result.reason
# endregion