-
Notifications
You must be signed in to change notification settings - Fork 6
Draft: moving over citation_exists.py requirement #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| from mellea.stdlib.requirement import Requirement, ValidationResult | ||
| from mellea.stdlib.base import Context, CBlock | ||
|
|
||
| import json | ||
| import os | ||
| import re | ||
| from eyecite import get_citations, clean_text | ||
| from typing import Any | ||
|
|
||
| # region: citation_exists function and helpers | ||
|
|
||
def normalize_case_name(name: str) -> str:
    """
    Converts a case name to a standard format.

    The result is lowercase, uses ``v.`` as the separator between parties,
    contains only letters, digits, periods, ampersands, apostrophes and
    single spaces, and has no leading or trailing whitespace.

    Args:
        name: A string representing the case name.

    Returns:
        A normalized case name.
    """
    # 1. Lowercase everything
    name = name.lower()

    # 2. Turn every whitespace run (tabs, newlines, ...) into one space
    #    *before* stripping punctuation; otherwise step 4 would delete the
    #    tab/newline and glue the neighbouring words together.
    name = re.sub(r'\s+', ' ', name)

    # 3. Normalize 'vs', 'vs.', 'v', 'v.', 'versus' to 'v.'. The optional
    #    period is consumed *after* the word boundary so that 'vs.' really
    #    becomes 'v.' instead of being left untouched.
    name = re.sub(r'\b(?:versus|vs|v)\b\.?', 'v.', name)

    # 4. Remove everything except letters, digits, periods, ampersands,
    #    spaces, and apostrophes
    name = re.sub(r"[^a-z0-9.& ']+", '', name)

    # 5. Dropping characters can leave doubled spaces behind; collapse them
    name = re.sub(r' {2,}', ' ', name)

    return name.strip()
|
|
||
| # might not be needed | ||
| # def ensure_list_of_dicts(obj: Any) -> list[dict]: | ||
| # """ | ||
| # Normalize any JSON-like object into a list of dictionaries. | ||
|
|
||
| # Accepts: | ||
| # - A JSON string (object or array) | ||
| # - A single dict | ||
| # - A list of dicts | ||
|
|
||
| # Args: | ||
| # obj: Any data type, ideally something that can unpacked into a dictionary | ||
|
|
||
| # Returns: | ||
| # The unpacked object in list of dictionary form or raises an error. | ||
| # """ | ||
| # # JSON string | ||
| # if isinstance(obj, str): | ||
| # try: | ||
| # obj = json.loads(obj) | ||
| # except json.JSONDecodeError as e: | ||
| # raise ValueError(f"Invalid JSON string: {e!s}") | ||
|
|
||
| # # Single dict | ||
| # if isinstance(obj, dict): | ||
| # return [obj] | ||
|
|
||
| # # List of dicts | ||
| # if isinstance(obj, list): | ||
| # if all(isinstance(item, dict) for item in obj): | ||
| # return obj | ||
| # else: | ||
| # raise ValueError("List contains non-dictionary elements") | ||
|
|
||
| # raise TypeError(f"Unsupported metadata format: {type(obj)}") | ||
|
|
||
| # alternatively: | ||
| # should this take in last_output instead of the whole context? | ||
| # get case name: take LLM output and extract case name --> a string which you get from ctx.last_output() is the input | ||
| # so the argument should be ctx.last_output.value: str | ||
|
|
||
def extract_case_names(ctx: Context) -> list[str]:
    """
    Given an LLM output, use eyecite to parse the text and collect case names.

    Args:
        ctx: Context whose last output holds text that may contain multiple
            citations.

    Returns:
        A list of unique case names of the form "<plaintiff> v. <defendant>",
        in the order eyecite returned the citations. The list is empty when
        the context has no output text or when no citation carries both a
        plaintiff and a defendant.
    """
    last_output = ctx.last_output()
    text = getattr(last_output, "value", None)
    if not text:
        # Nothing to parse: no last output, or its value is None/empty.
        return []

    # TODO: decide whether the text should first go through eyecite's
    # clean_text(text, ["html", "all_whitespace"]) and whether the optional
    # Hyperscan tokenizer is worth the extra dependency. Until then the raw
    # output is parsed with eyecite's default tokenizer.
    citations = get_citations(text)

    # A dict is used as an insertion-ordered set so that duplicates are
    # dropped while the result order stays deterministic between runs.
    case_names: dict[str, None] = {}
    for citation in citations:
        # eyecite exposes citation metadata as attributes, and citation kinds
        # other than full case citations may lack plaintiff/defendant
        # entirely, hence getattr with a default.
        metadata = getattr(citation, "metadata", None)
        plaintiff = getattr(metadata, "plaintiff", None)
        defendant = getattr(metadata, "defendant", None)
        if plaintiff and defendant:
            case_names[f"{plaintiff} v. {defendant}"] = None

    return list(case_names)
|
|
||
def citation_exists(ctx: Context, case_metadata: list[dict]) -> ValidationResult:
    """
    Given an LLM output and a list of dictionaries, checks that list (which represents a collection of
    case metadata json files) to see if the given case names can be found in it.

    Every case name extracted from the last output must match, after
    normalization, either the 'name' or the 'name_abbreviation' of some entry
    in case_metadata.

    Args:
        ctx: Context that contains the case names we're checking for
        case_metadata: a list of dictionaries which represents a collection of case metadata json files

    Returns:
        A validation result indicating if a match was found between given case names and database
    """
    if ctx is None:
        return ValidationResult(False, reason="No context provided in output")

    last_output = ctx.last_output()
    if last_output is None:
        return ValidationResult(False, reason="No last output found in context")

    # Pull "<plaintiff> v. <defendant>" names out of the output text.
    cited_names = extract_case_names(ctx)

    # NOTE(review): a reviewer questioned what this branch should return when
    # the output cites no cases; the original behavior (fail) is kept here.
    # The former `isinstance(case_names, list[str])` check was dropped because
    # isinstance() raises TypeError for parameterized generics, which made
    # every call with at least one extracted name crash.
    if not cited_names:
        return ValidationResult(False, reason="No case names provided in output")

    # Collect the normalized name and name_abbreviation of every database entry.
    known_names = set()
    for case in case_metadata:
        for key in ('name', 'name_abbreviation'):
            if key in case:
                known_names.add(normalize_case_name(case[key]))

    # NOTE(review): matching is exact after normalization. A reviewer pointed
    # out that this will miss non-verbatim citations (abbreviated party lists,
    # state names written differently inline vs. in the formal cite) and asked
    # for a more principled matching approach.
    for cited_name in cited_names:
        normalized_case_name = normalize_case_name(cited_name)
        if normalized_case_name not in known_names:
            # TODO: report the case name as it appeared in the output rather
            # than its normalized form.
            return ValidationResult(False, reason=f"'{normalized_case_name}' not found in database")

    return ValidationResult(True, reason="All case names found in database")
|
Comment on lines
+163
to
+173
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove commented-out code. |
||
|
|
||
|
|
||
class CaseNameExistsInDatabase(Requirement):
    """
    Checks if the output case name exists in the provided case metadata database.

    Validation is delegated to citation_exists, which receives the context
    under validation together with the metadata given at construction time.
    """

    def __init__(self, case_metadata: list[dict]):
        """
        Args:
            case_metadata: a list of dictionaries representing a collection of
                case metadata json files; it is passed unchanged to
                citation_exists, which reads the 'name' and
                'name_abbreviation' keys of each entry.
        """
        self._case_metadata = case_metadata
        super().__init__(
            description="The case name should exist in the provided case metadata database.",
            validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata),
        )
| # endregion | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| import pytest | ||
| from mellea.mellea.stdlib.reqlib.citation_exists import normalize_case_name, citation_exists | ||
|
|
||
| # Mock context for testing citation_exists | ||
|
|
||
| # make up my own model outputs | ||
|
|
||
| # can just check if case names are in one json file | ||
class MockContext:
    """Minimal stand-in for a context: it only knows how to report a last output."""

    def __init__(self, case_name):
        # Text that the fake model output will expose through `.value`.
        self._case_name = case_name

    def last_output(self):
        """Return a fresh object whose `value` attribute is the stored text."""
        output_cls = type("MockOutput", (), dict(value=self._case_name))
        return output_cls()
|
|
||
|
|
||
| # region: normalize_case_name tests | ||
# (raw input, expected normalized form) pairs for normalize_case_name
_NORMALIZE_CASES = [
    ("BOB VS SHMEEGUS", "bob v. shmeegus"),
    (
        "William Payne, Executor of John Payne v. William Dudley Executor of Fleet",
        "william payne executor of john payne v. william dudley executor of fleet",
    ),
    ("Ozwald v. Dickinson's Ex'rs", "ozwald v. dickinson's ex'rs"),
    ("Fox & al. v. Cosby", "fox & al. v. cosby"),
    ("Groves v. Graves", "groves v. graves"),
    (
        "Ozwald, Deniston, & Co. v. Dickinson's Ex'rs",
        "ozwald deniston & co. v. dickinson's ex'rs",
    ),
    ("Bobby- versus shmeegy", "bobby v. shmeegy"),
]


@pytest.mark.parametrize(("raw_name", "expected"), _NORMALIZE_CASES)
def test_normalize_case_name(raw_name, expected):
    """Each raw case name must normalize to exactly the expected string."""
    normalized = normalize_case_name(raw_name)
    assert normalized == expected
| # endregion | ||
|
|
||
| # region: citation_exists tests | ||
@pytest.mark.parametrize("case_name,expected", [
    ("Bob v. Shmeegus", False),
    ("Gimli versus Legolas", False),
    ("Groves v. Graves", True),
    ("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", True),
    ("Payne v. Dudley", True),
    ("Fox & al. v. Cosby", True),
    ("Fox v. Cosby", True),
])
def test_citation_exists(tmp_path, case_name, expected):
    """citation_exists should accept only case names present in the metadata."""
    # create mock context
    ctx = MockContext(case_name)

    # citation_exists takes the metadata itself (a list of dicts), not a path
    # to a folder, so the fixture is built in memory instead of pointing at a
    # machine-specific directory.
    # NOTE(review): these entries were reconstructed from the expected-True
    # inputs above; swap in the real cases_metadata content if it differs.
    case_metadata = [
        {"name": "Groves v. Graves", "name_abbreviation": "Groves v. Graves"},
        {
            "name": "William Payne, Executor of John Payne v. William Dudley Executor of Fleet",
            "name_abbreviation": "Payne v. Dudley",
        },
        {"name": "Fox & al. v. Cosby", "name_abbreviation": "Fox v. Cosby"},
    ]

    # NOTE(review): the inputs are bare case names without a
    # volume/reporter/page; confirm that eyecite extracts a citation from
    # such text, otherwise the expected-True cases cannot pass.
    result = citation_exists(ctx, case_metadata)
    assert result.as_bool() == expected, result.reason
|
|
||
| # endregion |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.