Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions mellea_contribs/reqlib/citation_exists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from mellea.stdlib.requirement import Requirement, ValidationResult
from mellea.stdlib.base import Context, CBlock

import json
import os
import re
from eyecite import get_citations, clean_text
from typing import Any

# region: citation_exists function and helpers

def normalize_case_name(name) -> str:
"""
Converts a case name to a standard format.

Args:
name: A string representing the case name.

Returns:
A normalized case name.
"""
# 1. Lowercase everything
name = name.lower()

# 2. Normalize 'vs', 'vs.', 'v', 'versus' to 'v.'
name = re.sub(r'\b(vs\.?|versus|v)(?!\.)\b', 'v.', name)

# 3. Remove all non-alphanumeric characters except periods, spaces, and apostrophes
name = re.sub(r"[^a-z0-9.& ']+", '', name)

# 4. Replace multiple spaces with a single space
name = re.sub(r'\s+', ' ', name)

return name.strip()

# might not be needed
# def ensure_list_of_dicts(obj: Any) -> list[dict]:
# """
# Normalize any JSON-like object into a list of dictionaries.

# Accepts:
# - A JSON string (object or array)
# - A single dict
# - A list of dicts

# Args:
# obj: Any data type, ideally something that can unpacked into a dictionary

# Returns:
# The unpacked object in list of dictionary form or raises an error.
# """
# # JSON string
# if isinstance(obj, str):
# try:
# obj = json.loads(obj)
# except json.JSONDecodeError as e:
# raise ValueError(f"Invalid JSON string: {e!s}")

# # Single dict
# if isinstance(obj, dict):
# return [obj]

# # List of dicts
# if isinstance(obj, list):
# if all(isinstance(item, dict) for item in obj):
# return obj
# else:
# raise ValueError("List contains non-dictionary elements")

# raise TypeError(f"Unsupported metadata format: {type(obj)}")

# alternatively:
# should this take in last_output instead of the whole context?
# get case name: take LLM output and extract case name --> a string which you get from ctx.last_output() is the input
# so the argument should be ctx.last_output.value: str
Comment on lines +36 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. What was the purpose of this?
  2. Remove commented out code.


def extract_case_names(ctx: Context) -> list[str]:
"""
Given an LLM output, use eyecite to parse the text and collect case names.

Args:
ctx: An LLM output that may contain multiple citations.

Returns:
A list of case names.
"""
# should i clean text??

# install hyperscan if not already installed
# !pip install hyperscan
# tokenizer = HyperscanTokenizer(cache_dir=".test_cache")
# citations = get_citations(cleaned_text, tokenizer=tokenizer)

# or this?
# cleaned_text = clean_text(text, ["html", "all_whitespace"])
# citations = get_citations(cleaned_text)
Comment on lines +89 to +96
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on here?


# get_citations outputs a list of citations
citations = get_citations(ctx.last_output().value)
case_names = set()

for citation in citations:
plaintiff = citation.metadata.get("plaintiff")
defendant = citation.metadata.get("defendant")
if plaintiff and defendant:
case_names.add(f"{plaintiff} v. {defendant}")
# name = citation.metadata['plaintiff'] + " v. " + citation.metadata['defendant']
# case_names.add(name)
Comment on lines +103 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. What's the purpose of this?
  2. Are these actually canonical names/references?
  3. What happens if you don't have a plaintiff and defendent? Can that ever happen? If not -> assert. If yes -> handle exceptiona lcases.


return list(case_names)

def citation_exists(ctx: Context, case_metadata: list[dict]) -> ValidationResult:
"""
Given an LLM output and a list of dictionaries, checks that list (which represents a collection of
case metadata json files) to see if the given case names can be found in it.

Args:
ctx: Context that contains the case names we're checking for
case_metadata: a list of dictionaries which represents a collection of case metadata json files

Returns:
A validation result indicating if a match was found between given case names and database
"""
if ctx is None:
return ValidationResult(False, reason="No context provided in output")

# 1) this will spit out a bunch of words --> look through to extract case names
# 2) use eyecite (might have to do some conversion)
last_output = ctx.last_output()

# if last_output is None or not getattr(output, "value", None):
if last_output is None:
return ValidationResult(False, reason="No last output found in context")

# 3) run checking
# call get_case_name func
case_names = extract_case_names(ctx)

if not case_names or not isinstance(case_names, list[str]):
return ValidationResult(False, reason="No case names provided in output")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should probably return True. The reason is good.


normalized_case_names = [normalize_case_name(case_name) for case_name in case_names]

case_names = set()
case_name_abb = set()

# add name and name_abbreviation from the database
for case in case_metadata:
if 'name' in case:
case_names.add(normalize_case_name(case['name']))
if 'name_abbreviation' in case:
case_name_abb.add(normalize_case_name(case['name_abbreviation']))
Comment on lines +148 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach seems like it will make a lot of errors.

What about cases where the case name isn't verbatim? E.g., sometimes if there is a large set of parties on one side or the other there will be an abbreviation of that in the cite. State names are often also different in the formal cite vs how they're cited inline.

Additionally, there is a lot of string manipulation here that I think can be streamlined or done in a more principles way. Should we implement this as something like likely_equivalent(c1, c2) where c1 and c2 are eyecite citation objects.


# Check both name and name_abbreviation
for normalized_case_name in normalized_case_names:
if normalized_case_name not in case_names and normalized_case_name not in case_name_abb:
# probably want to change this to the actual case name at some point
# maybe keep a tuple structure or something
return ValidationResult(False, reason=f"'{normalized_case_name}' not found in database")

return ValidationResult(True, reason="All case names found in database")

# check if this code chunk is right later
# db_names = {normalize_case_name(c["name"]) for c in case_metadata if "name" in c}
# db_abbrevs = {
# normalize_case_name(c["name_abbreviation"]) for c in case_metadata if "name_abbreviation" in c
# }

# for name in normalized_output_names:
# if name not in db_names and name not in db_abbrevs:
# return ValidationResult(False, reason=f"Case '{name}' not found in database")

# return ValidationResult(True, reason="All case names found in database")
Comment on lines +163 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove commented out code.



class CaseNameExistsInDatabase(Requirement):
"""
Checks if the output case name exists in the provided case metadata database.
"""
def __init__(self, case_metadata: str):
self._case_metadata = case_metadata
super().__init__(
description="The case name should exist in the provided case metadata database.",
validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata),
)
# endregion
52 changes: 52 additions & 0 deletions test/test_citation_exists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
from mellea.mellea.stdlib.reqlib.citation_exists import normalize_case_name, citation_exists

# Mock context for testing citation_exists

# make up my own model outputs

# can just check if case names are in one json file
class MockContext:
def __init__(self, case_name):
self._case_name = case_name

def last_output(self):
return type("MockOutput", (), {"value": self._case_name})()


# region: normalize_case_name tests
@pytest.mark.parametrize("raw_name,expected", [
("BOB VS SHMEEGUS", "bob v. shmeegus"),
("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", "william payne executor of john payne v. william dudley executor of fleet"),
("Ozwald v. Dickinson's Ex'rs", "ozwald v. dickinson's ex'rs"),
("Fox & al. v. Cosby", "fox & al. v. cosby"),
("Groves v. Graves", "groves v. graves"),
("Ozwald, Deniston, & Co. v. Dickinson's Ex'rs", "ozwald deniston & co. v. dickinson's ex'rs"),
("Bobby- versus shmeegy", "bobby v. shmeegy")
])

def test_normalize_case_name(raw_name, expected):
assert normalize_case_name(raw_name) == expected
# endregion

# region: citation_exists tests
@pytest.mark.parametrize("case_name,expected", [
("Bob v. Shmeegus", False),
("Gimli versus Legolas", False),
("Groves v. Graves", True),
("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", True),
("Payne v. Dudley", True),
("Fox & al. v. Cosby", True),
("Fox v. Cosby", True),
])

def test_citation_exists(tmp_path, case_name, expected):
# create mock context
ctx = MockContext(case_name)
# path to metadata folder
# db_folder = "/Users/anooshkapendyal/Desktop/mellea/mellea/test/stdlib_basics/legal/cases_metadata"

result = citation_exists(ctx, db_folder)
assert result.as_bool() == expected, result.reason

# endregion