Skip to content

Commit 9f40bb7

Browse files
committed
🚧 TMP
1 parent 6580044 commit 9f40bb7

File tree

7 files changed

+401
-135
lines changed

7 files changed

+401
-135
lines changed

haiku/llm_judges/__init__.py

Whitespace-only changes.

haiku/llm_judges/base.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
Abstract base class for haiku LLM judges.
3+
4+
Provides shared structure and style scoring. Subclasses implement
5+
score_single to define their own weighting strategies.
6+
"""
7+
8+
import re
9+
from abc import ABC, abstractmethod
10+
11+
import aiohttp
12+
13+
from llm_judges.deploy import VLLM_PORT
14+
from llm_judges.nlp import diff_syllables_count, segment_haiku_lines
15+
16+
17+
18+
# =============================================================================
19+
# Shared Prompt Template
20+
# =============================================================================
21+
22+
# Vocabulary the judge rewards haikus for using coherently.
# NOTE: the original list was missing a comma after "volume", which silently
# concatenated adjacent string literals into the single entry "volumefunction".
MODAL_VOCABS = [
    "modal",
    "volume",
    "function",
    "sandbox",
    "flash",
    "inference",
    "train",
]


def generate_haiku_judge_prompt(prompt: str, response: str, label: str) -> str:
    """Build the LLM-judge prompt used to score a haiku response.

    Args:
        prompt: The topic the haiku was asked to be about.
        response: The haiku text to evaluate.
        label: An existing reference poem the response is compared against.

    Returns:
        A prompt string instructing the judge to emit a single 0-20 total
        across four 5-point criteria: relevance, poetic quality, Modal
        vocabulary usage, and comparison to the reference poem.
    """
    modal_vocab_str = ", ".join(MODAL_VOCABS)

    return f"""You are evaluating a haiku poem.

Score the response based on the following criteria:

Relevance (5 points total)
- 5 points: if the central theme and punchline of the haiku is "{prompt}"
- 3 points: if the response directly discusses "{prompt}" but it is not the central theme
- 2 points: if the response is relevant to the topic "{prompt}" but very plain
- 0 points: if the response is not relevant to the topic "{prompt}"

Poetic quality (5 points total)
- 5 points: if the response makes sense, can be considered a poetic haiku, with a clear theme and punchline
- 3 points: if the response makes sense, but is not very poetic
- 1 point: if the response doesn't make sense
- 0 points: if the response is not poetic and incoherent

Uses Modal vocabulary (5 points total): (modal vocab: {modal_vocab_str})
- 5 points: if the response uses the above words in a way that is coherent and relevant to the topic "{prompt}"
- 3 points: if the response uses the above words in a way that is not relevant to the topic "{prompt}"
- 0 points: if the response does not use the above words

Better than the existing poem (5 points total):
Given the existing poem, score the response by comparing its quality to the existing poem:
{label}
- 5 points: if the response is better than the poem "{label}".
- 3 points: if the response is equal in quality to the poem "{label}".
- 0 points: if the response is worse than the poem "{label}".

Add up the scores from the above criteria to get the total score.

--
**Topic:** {prompt}

**Response to evaluate:**
{response}
---

Output ONLY a single number (0-20), nothing else."""
74+
75+
class HaikuJudge(ABC):
    """Abstract base class for haiku judges.

    Shared scoring (both normalized to [0, 1]):
    - score_haiku_structure: line count plus 5-7-5 syllable accuracy
    - score_haiku_style: LLM-judge rubric total (0-20) divided by MAX_STYLE_SCORE

    Subclasses implement score_single to combine these into a final [0, 1] score.
    """

    # Upper bounds for the two scoring axes. score_haiku_structure already
    # returns a value in [0, 1]; the LLM judge emits a raw 0-20 rubric total
    # that score_haiku_style normalizes by MAX_STYLE_SCORE.
    MAX_STRUCTURE_SCORE = 1
    MAX_STYLE_SCORE = 20

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this judge, used as the Modal app/deployment name."""
        ...

    @staticmethod
    def score_syllable_line(diff: int, allow_off_by_one: bool = False) -> float:
        """Score a single line's syllable count: 1 for exact, 0.5 for off-by-1 (if allowed), 0 otherwise."""
        if diff == 0:
            return 1
        if diff == 1:
            return 0.5 if allow_off_by_one else 0
        return 0

    @staticmethod
    def score_haiku_structure(response: str, cmudict: dict, allow_off_by_one: bool = False) -> float:
        """Score haiku structure (0-1): 1/4 for having exactly 3 lines, plus up to
        1/4 per line for matching the classic 5-7-5 syllable pattern."""
        lines = segment_haiku_lines(response)
        score = 0.0
        fractional_multiplier = 0.25

        if len(lines) == 3:
            score += fractional_multiplier

        # Score each present line against the expected 5-7-5 syllable counts.
        for idx, expected_syllables in enumerate((5, 7, 5)):
            if len(lines) > idx:
                score += HaikuJudge.score_syllable_line(
                    diff_syllables_count(lines[idx], expected_syllables, cmudict),
                    allow_off_by_one,
                ) * fractional_multiplier

        return score

    @staticmethod
    async def score_haiku_style(
        model_name: str,
        session: aiohttp.ClientSession,
        prompt: str,
        response: str,
        label: str,
        vllm_base_url: str = f"http://localhost:{VLLM_PORT}",
    ) -> float:
        """Score haiku style via the vLLM judge, normalized to [0, 1]; returns 0 on any error.

        The judge prompt requests a 0-20 rubric total, so the raw score is
        clamped to [0, MAX_STYLE_SCORE] and divided by MAX_STYLE_SCORE.
        (The original clamped and divided by 10, which both disagreed with the
        0-20 rubric and collapsed every raw score >= 10 to 1.0.)
        """
        judge_prompt = generate_haiku_judge_prompt(prompt, response, label)

        try:
            async with session.post(
                f"{vllm_base_url}/v1/chat/completions",
                headers={"content-type": "application/json"},
                json={
                    "model": model_name,
                    "messages": [{"role": "user", "content": judge_prompt}],
                    "max_tokens": 100,
                },
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    print(f"vLLM error: {resp.status} - {error_text}")
                    return 0

                data = await resp.json()
                score_text = data["choices"][0]["message"]["content"].strip()
                print(f"Scored {response} with score {score_text}")

                # Extract the first number the judge emitted; tolerate any
                # surrounding text the model adds despite the instructions.
                match = re.search(r"(\d+(?:\.\d+)?)", score_text)
                if match:
                    score = float(match.group(1))
                    return min(max(score, 0), HaikuJudge.MAX_STYLE_SCORE) / HaikuJudge.MAX_STYLE_SCORE
                return 0
        except Exception as e:
            # Best-effort: a judge/network failure should not crash the caller,
            # it just contributes a zero style score.
            print(f"Error scoring response: {e}")
            return 0

    @abstractmethod
    async def score_single(
        self,
        model_name: str,
        session: aiohttp.ClientSession,
        prompt: str,
        response: str,
        label: str,
        cmudict: dict,
    ) -> float:
        """Score a single haiku. Returns a normalized score in [0, 1]."""
        ...
180+

0 commit comments

Comments
 (0)