Skip to content

Commit a962c75

Browse files
authored
Merge branch 'main' into test-ci-workflow
2 parents 3175871 + c8bd9a1 commit a962c75

File tree

5 files changed

+120828
-0
lines changed

5 files changed

+120828
-0
lines changed

.mergify.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
merge_protections:
2+
- name: Enforce conventional commit
3+
description: Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/
4+
if:
5+
- base = main
6+
success_conditions:
7+
- "title ~=
8+
^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert)(?:\\(.+\
9+
\\))?:"
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
from mellea.stdlib.requirement import Requirement, ValidationResult
2+
from mellea.stdlib.base import Context
3+
from eyecite import get_citations
4+
from citeurl import Citator
5+
from typing import Any, Optional
6+
from playwright.sync_api import sync_playwright
7+
from urllib.parse import urljoin, urlparse, parse_qs
8+
9+
import json
10+
import requests
11+
12+
# region citation_exists helpers
13+
14+
"""
15+
Validator: Ensure that every case-law citation in an LLM output corresponds to a real case in the
16+
provided case metadata database.
17+
18+
Process:
19+
1. Extract citations from LLM output using citeurl.
20+
2. Map citation objects to URLs.
21+
3a. For each case.law URL:
22+
- Fetch JSON metadata.
23+
- Compare its case ID against the known database.
24+
3b. For each non case.law URL:
25+
- Run the original text through eyecite to extract volume, reporter, and page.
26+
- Check if any cases in the database match based on this information.
27+
4. If any citation fails, return ValidationResult(False).
28+
5. If all succeed, return ValidationResult(True).
29+
"""
30+
31+
def text_to_urls(text: str) -> dict[str, str] | ValidationResult:
32+
"""
33+
Extracts all citation URLs from the given text using citeurl.
34+
35+
Args:
36+
text: An LLM output
37+
38+
Returns:
39+
A dictionary of citation URLs and the corresponding text.
40+
41+
Behavior:
42+
If a citation does not have a URL attribute, we return a ValidationResult(False)
43+
so that the parent validator can fail accordingly.
44+
"""
45+
citator = Citator()
46+
citations = citator.list_cites(text)
47+
48+
urls = {}
49+
errors = []
50+
51+
for citation in citations:
52+
if hasattr(citation, "URL") and citation.URL:
53+
# Map the URL to the text corresponding to the citation
54+
urls[citation.URL] = citation.text
55+
else:
56+
# Record a descriptive error about the invalid citation object
57+
errors.append(f"Citation has no URL attribute: {repr(citation)}")
58+
59+
if errors:
60+
# Raise one combined error
61+
error_msg = "Some citations did not contain URLs:\n" + "\n".join(errors)
62+
return ValidationResult(False, reason=error_msg)
63+
64+
return urls
65+
66+
67+
def extract_case_metadata_url(case_url: str) -> ValidationResult | str:
68+
"""
69+
Converts a case.law URL to the corresponding static JSON metadata URL.
70+
71+
Args:
72+
case_url: A cite.case.law page
73+
74+
Returns:
75+
A URL to the JSON metadata for the case or a false ValidationResult if the link cannot be found
76+
"""
77+
# Take the full input URL and split into structured components
78+
parsed = urlparse(case_url)
79+
# Turn the query part into a dictionary
80+
params = parse_qs(parsed.query)
81+
82+
# Use None as a fallback if a value is missing
83+
reporter = params.get("reporter", [None])[0]
84+
volume = params.get("volume", [None])[0]
85+
case = params.get("case", [None])[0]
86+
87+
if not reporter or not volume or not case:
88+
# Use playwright if URL parsing doesn't work
89+
with sync_playwright() as pw:
90+
browser = pw.chromium.launch()
91+
page = browser.new_page()
92+
page.goto(case_url)
93+
94+
# Wait for the metadata link to appear
95+
link = page.wait_for_selector("a:has-text('Download case metadata')")
96+
browser.close()
97+
98+
if not link:
99+
return ValidationResult(False, reason=f"No metadata link found on page: {case_url}")
100+
101+
# Extract relative href
102+
href = link.get_attribute("href")
103+
if not href:
104+
return ValidationResult(False, reason=f"Metadata link missing href attribute on page: {case_url}")
105+
106+
# Build the absolute metadata URL
107+
return urljoin(case_url, href)
108+
109+
return f"https://static.case.law/{reporter}/{volume}/cases/{case}.json"
110+
111+
112+
def metadata_url_to_json(metadata_url: str) -> dict:
113+
"""
114+
Fetches JSON metadata for a case.
115+
116+
Args:
117+
metadata_url: Fully-qualified URL to metadata.json
118+
119+
Returns:
120+
A dictionary representing the JSON metadata.
121+
"""
122+
resp = requests.get(metadata_url)
123+
resp.raise_for_status()
124+
return resp.json()
125+
126+
127+
def collect_ids_in_database(database: list[dict]) -> set:
128+
"""
129+
Collects all case IDs from the provided caselaw metadata.
130+
131+
Args:
132+
database: A list of case dictionaries loaded from a caselaw JSON dataset.
133+
134+
Returns:
135+
A set of all unique case IDs.
136+
"""
137+
return {case["id"] for case in database}
138+
139+
140+
def parse_db_cite(cite: str) -> tuple:
141+
"""
142+
Given a citation in the form of a string, return a normalized tuple breaking the
143+
volume, reporter, and page into distinct parts.
144+
145+
Args:
146+
cite: A string representing the citation found in the text.
147+
148+
Returns:
149+
A tuple containing the volume, normalized reporter, and page of a citation.
150+
"""
151+
# TODO: Could mishandle the following: “U. S.”, “S. Ct.”, “F. Supp. 2d”
152+
parts = cite.split()
153+
154+
# If the citation has less then 3 parts, it is likely irregular
155+
if len(parts) < 3:
156+
return None
157+
158+
volume = parts[0]
159+
page = parts[-1]
160+
reporter = " ".join(parts[1:-1])
161+
normalized_reporter = reporter.lower().replace(".", "")
162+
163+
return (volume, normalized_reporter, page)
164+
165+
166+
def build_citation_index(database: list[dict]) -> set[tuple]:
167+
"""
168+
Extract all of the citations in the database for easy comparison.
169+
170+
Args:
171+
database: A list of case dictionaries loaded from a caselaw JSON dataset.
172+
173+
Returns:
174+
A set of normalized tuples.
175+
"""
176+
index = set()
177+
178+
for case in database:
179+
# There can be multiple citations for each case
180+
for c in case.get("citations", []):
181+
parsed = parse_db_cite(c["cite"])
182+
if parsed:
183+
index.add(parsed)
184+
185+
return index
186+
187+
188+
def non_caselaw_citation_exists(text: str, database: list[dict]) -> bool | ValidationResult:
189+
"""
190+
Given the text corresponding to a citation, check whether that citation can matched to
191+
the cases in the database.
192+
193+
We first use a deterministic approach, like by matching against citations in the database.
194+
Then, we fuzzy match across features like case name, volume, and year.
195+
Finally, we resort to using LLM-as-a-judge to determine if a match exists.
196+
197+
Args:
198+
text: A string containing a citation.
199+
database: A list of case dictionaries loaded from a caselaw JSON dataset.
200+
201+
Returns:
202+
Boolean indicating whether a match was found or ValidationResult if there was an error.
203+
"""
204+
# The citations field in the original database represents how each case can be cited,
205+
# but it's not exhaustive.
206+
207+
# Get_citations is an eyecite function (extracts information from text)
208+
citations = get_citations(text)
209+
# Build our database in the proper format for easy access
210+
citation_index = build_citation_index(database)
211+
212+
# Return False ValidationResult if multiple or no citations have been found in the text.
213+
if len(citations) != 1:
214+
return ValidationResult(False, reason="Error from parsing citations with eyecite.")
215+
216+
try:
217+
groups = citations[0].groups
218+
vol = groups["volume"]
219+
reporter = groups["reporter"]
220+
page = groups["page"]
221+
except (KeyError, TypeError, IndexError):
222+
return ValidationResult(False,
223+
reason="Error from parsing citations with eyecite.")
224+
225+
normalized_reporter = reporter.lower().replace(".", "")
226+
227+
# TODO: if (vol, normalized_reporter, page) not in citation_index,
228+
# resort to other measures, like fuzzy matching names and/or LLM-as-a-judge
229+
230+
return (vol, normalized_reporter, page) in citation_index
231+
232+
# endregion
233+
234+
235+
# region citation_exists function
236+
237+
def citation_exists(ctx: Context, database: list[dict]) -> ValidationResult:
238+
"""
239+
Validator:
240+
Ensures that every cite.case.law URL in the LLM output corresponds to a real case in the provided case metadata database.
241+
242+
Args:
243+
ctx: Mellea runtime context containing the last LLM output.
244+
database: Parsed caselaw metadata database of JSON objects.
245+
246+
Returns:
247+
ValidationResult indicating pass/fail.
248+
"""
249+
if ctx is None:
250+
return ValidationResult(False, reason="No context provided in output.")
251+
252+
last_output = ctx.last_output()
253+
254+
if last_output is None:
255+
return ValidationResult(False, reason="No last output found in context.")
256+
257+
urls_or_error = text_to_urls(last_output)
258+
259+
# text_to_urls may return a ValidationResult (error condition)
260+
if isinstance(urls_or_error, ValidationResult):
261+
return urls_or_error
262+
263+
# List of urls of citations found in the LLM output
264+
output_citation_urls = list(urls_or_error.keys())
265+
266+
if output_citation_urls is None or output_citation_urls == []:
267+
# No citations, so trivially valid
268+
return ValidationResult(True, reason="No citations found.")
269+
270+
database_ids = collect_ids_in_database(database)
271+
272+
for url in output_citation_urls:
273+
# If this URL is Caselaw, do direct comparison within database by using case id
274+
if "cite.case.law" in url:
275+
try:
276+
metadata_url = extract_case_metadata_url(url)
277+
278+
# Check if extract_case_metadata_url returns a ValidationResult and propagate it
279+
if isinstance(metadata_url, ValidationResult):
280+
return metadata_url
281+
282+
metadata = metadata_url_to_json(metadata_url)
283+
case_id = metadata["id"]
284+
285+
except Exception as e:
286+
return ValidationResult(False, reason=f"Failed to retrieve metadata for {url}: {e}")
287+
288+
if case_id not in database_ids:
289+
return ValidationResult(False, reason=f"Case {case_id} not found in database")
290+
291+
# Non-case.law citations: pass into Eyecite and see if citations match
292+
else:
293+
# TODO: This logic might need some reworking because for cases where a match
294+
# cannot be found, we cannot verify this citation, but we also cannot disprove it
295+
# (due to factors like reporter names varying, citation lists in the database
296+
# not being exhaustive, and parsing being lossy or ambiguous)
297+
text = urls_or_error[url]
298+
result = non_caselaw_citation_exists(text, database)
299+
300+
# Case 1: hard failure -> propagate
301+
if isinstance(result, ValidationResult):
302+
return result
303+
304+
# Case 2: deterministic match found -> OK
305+
if result is True:
306+
continue
307+
308+
# Case 3: result is False -> inconclusive -> do NOT fail
309+
# Explicitly allow this to pass for now
310+
continue
311+
312+
return ValidationResult(True, reason="All case.law citations verified; non-case.law citations did not fail verification.")
313+
314+
315+
class CaseNameExistsInDatabase(Requirement):
316+
"""
317+
Requirement wrapper for Mellea that ensures case citations in LLM output
318+
refer to real cases in the provided metadata database.
319+
"""
320+
def __init__(self, case_metadata: list[dict]):
321+
self._case_metadata = case_metadata
322+
super().__init__(
323+
description="The case name should exist in the provided case metadata database.",
324+
validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata),
325+
)
326+
327+
# endregion

0 commit comments

Comments
 (0)