Skip to content

Commit fa8af37

Browse files
authored
Expose templates of LLM grader instances and default templates of LLM grader classes. (#124)
* Expose templates of grader instances and default templates of grader classes. * Resolve code review feedback. * Update function argument type annotation.
1 parent 020f69b commit fa8af37

File tree

4 files changed

+80
-17
lines changed

4 files changed

+80
-17
lines changed

openjudge/graders/agent/action/action_alignment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ class ActionAlignmentGrader(LLMGrader):
170170
>>> print(f"Score: {result.score}") # Expected: 1.0
171171
"""
172172

173+
DEFAULT_TEMPLATE = DEFAULT_ACTION_ALIGNMENT_TEMPLATE
174+
173175
def __init__(
174176
self,
175177
model: BaseChatModel | dict,
176-
template: Optional[PromptTemplate] = DEFAULT_ACTION_ALIGNMENT_TEMPLATE,
178+
template: Optional[PromptTemplate] = DEFAULT_TEMPLATE,
177179
language: LanguageEnum = LanguageEnum.EN,
178180
strategy: BaseEvaluationStrategy | None = None,
179181
):
@@ -183,7 +185,7 @@ def __init__(
183185
Args:
184186
model: The chat model to use for evaluation, either as a BaseChatModel instance or config dict
185187
template: The prompt template for action alignment evaluation.
186-
Defaults to DEFAULT_ACTION_ALIGNMENT_TEMPLATE.
188+
Defaults to DEFAULT_TEMPLATE.
187189
language: The language for the evaluation prompt. Defaults to LanguageEnum.EN.
188190
strategy: The evaluation strategy to use. Defaults to DirectStrategy.
189191
"""
@@ -192,7 +194,7 @@ def __init__(
192194
mode=GraderMode.POINTWISE,
193195
description="Evaluate action alignment with plan",
194196
model=model,
195-
template=template or DEFAULT_ACTION_ALIGNMENT_TEMPLATE,
197+
template=template or self.DEFAULT_TEMPLATE,
196198
language=language,
197199
strategy=strategy,
198200
)

openjudge/graders/llm_grader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class LLMGrader(BaseGrader):
5353
callback (Callable): Function to process model response metadata.
5454
"""
5555

56+
# The default template value is just a placeholder.
57+
# Extended classes must set proper value to DEFAULT_TEMPLATE
58+
DEFAULT_TEMPLATE = PromptTemplate(messages={})
59+
5660
def __init__(
5761
self,
5862
model: BaseChatModel | dict,
@@ -108,6 +112,9 @@ def __init__(
108112
else:
109113
self.language = language
110114

115+
if not template:
116+
raise ValueError("Missing template argument value")
117+
111118
if isinstance(template, str):
112119
self.template = PromptTemplate(
113120
messages={
@@ -343,6 +350,15 @@ async def _aevaluate(self, **kwargs: Any) -> GraderScore | GraderRank:
343350
raise ValueError(f"Unsupported grader mode: {self.mode}")
344351
return result
345352

353+
def get_template(self, language: LanguageEnum = LanguageEnum.EN) -> Dict[str, Any]:
354+
"""Return the template of the specified language in this grader instance"""
355+
return self.template.get_prompt(language)
356+
357+
@classmethod
358+
def get_default_template(cls, language: LanguageEnum = LanguageEnum.EN) -> Dict[str, Any]:
359+
"""Return the default template of the specified language in this grader class"""
360+
return cls.DEFAULT_TEMPLATE.get_prompt(language)
361+
346362
@staticmethod
347363
def get_metadata() -> Dict[str, Any]:
348364
"""Return the docstring of the aevaluate method to explain how LLMGrader works with LLM."""

tests/graders/agent/action/test_action_alignment.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,28 @@ def test_initialization(self):
6161
assert grader.name == "action_alignment"
6262
assert grader.model == mock_model
6363

64+
language_template = grader.get_template(LanguageEnum.ZH)
65+
assert len(language_template) == 1
66+
assert "zh" in language_template
67+
template = language_template["zh"]
68+
assert len(template) == 1
69+
assert len(template[0]) == 2
70+
assert template[0]["role"] == "user"
71+
assert template[0]["content"].startswith(
72+
"你是一名分析智能体行为的专家。你的任务是评估智能体是否执行了与其声明的计划或推理一致的动作。"
73+
)
74+
75+
language_template = grader.get_default_template(LanguageEnum.EN)
76+
assert len(language_template) == 1
77+
assert "en" in language_template
78+
template = language_template["en"]
79+
assert len(template) == 1
80+
assert len(template[0]) == 2
81+
assert template[0]["role"] == "user"
82+
assert template[0]["content"].startswith(
83+
"You are an expert in analyzing agent behavior. Your task is to evaluate whether the agent executes an action that aligns with its stated plan or reasoning."
84+
)
85+
6486
@pytest.mark.asyncio
6587
async def test_successful_evaluation_aligned(self):
6688
"""Test successful evaluation with good alignment"""
@@ -156,12 +178,8 @@ async def test_error_handling(self):
156178
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
157179
RUN_QUALITY_TESTS = bool(OPENAI_API_KEY and OPENAI_BASE_URL)
158180

159-
pytestmark = pytest.mark.skipif(
160-
not RUN_QUALITY_TESTS,
161-
reason="Requires API keys and base URL to run quality tests",
162-
)
163-
164181

182+
@pytest.mark.skipif(not RUN_QUALITY_TESTS, reason="Requires API keys and base URL to run quality tests")
165183
@pytest.mark.quality
166184
class TestActionAlignmentGraderQuality:
167185
"""Quality tests for ActionAlignmentGrader - testing evaluation quality"""

tests/graders/test_llm_grader.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from openjudge.graders.llm_grader import LLMGrader
4343
from openjudge.graders.schema import GraderError
4444
from openjudge.models.openai_chat_model import OpenAIChatModel
45+
from openjudge.models.schema.prompt_template import LanguageEnum
4546
from openjudge.runner.grading_runner import GraderConfig, GradingRunner
4647

4748
# ==================== UNIT TESTS ====================
@@ -60,12 +61,18 @@ def test_initialization_failure_without_template(self):
6061
model=AsyncMock(),
6162
name="foo",
6263
)
64+
assert "Missing template argument value" in str(error_obj.value)
65+
66+
def test_initialization_failure_with_invalid_template_type(self):
67+
"""Test initialization failure without template"""
68+
with pytest.raises(ValueError) as error_obj:
69+
LLMGrader(model=AsyncMock(), name="foo", template=AsyncMock())
6370
assert "Template must be a str, list, dict or PromptTemplate object" in str(error_obj.value)
6471

6572
def test_initialization_with_string_template(self):
6673
"""Test successful initialization with string template"""
6774
mock_model = AsyncMock()
68-
template_str = """You're a LLM query answer relevance grader, you'll received Query/Response:
75+
template_str = """You're a LLM query answer relevance grader, you'll receive Query/Response:
6976
Query: {query}
7077
Response: {response}
7178
Please read query/response, if the Response answers the Query, return 1, return 0 if no.
@@ -98,7 +105,7 @@ def test_initialization_with_dict_template(self):
98105
},
99106
{
100107
"role": "user",
101-
"content": """You'll received Query/Response:
108+
"content": """You'll receive Query/Response:
102109
Query: {query}
103110
Response: {response}
104111
Please read query/response, if the Response answers the Query, return 1, return 0 if no.
@@ -139,7 +146,7 @@ def test_initialization_with_model_dict(self):
139146
"api_key": "test-key",
140147
}
141148

142-
template_str = """You're a LLM query answer relevance grader, you'll received Query/Response:
149+
template_str = """You're a LLM query answer relevance grader, you'll receive Query/Response:
143150
Query: {query}
144151
Response: {response}
145152
Please read query/response, if the Response answers the Query, return 1, return 0 if no.
@@ -158,8 +165,29 @@ def test_initialization_with_model_dict(self):
158165
)
159166

160167
assert grader.name == "test_llm_grader"
161-
assert isinstance(grader.model, OpenAIChatModel)
162168
# Note: We can't easily check the model config since it's private
169+
assert isinstance(grader.model, OpenAIChatModel)
170+
171+
language_template = grader.get_template()
172+
assert len(language_template) == 1
173+
assert LanguageEnum.EN in language_template
174+
templates = language_template[LanguageEnum.EN]
175+
assert len(templates) == 2
176+
for t in templates:
177+
assert len(t) == 2
178+
assert "role" in t
179+
assert "content" in t
180+
181+
if t["role"] == "system":
182+
assert (
183+
"You are a professional evaluation assistant. Please evaluate according to the user's requirements."
184+
in t["content"]
185+
)
186+
elif t["role"] == "user":
187+
assert "You're a LLM query answer relevance grader, you'll receive Query/Response" in t["content"]
188+
189+
default_template = grader.get_default_template()
190+
assert len(default_template) == 0
163191

164192
@pytest.mark.asyncio
165193
async def test_pointwise_evaluation_success(self):
@@ -217,7 +245,7 @@ async def test_listwise_evaluation_success(self):
217245
mock_model.achat = AsyncMock(return_value=mock_response)
218246

219247
# Create grader with template that follows the specification in docs
220-
template = """You're a LLM query answer ranking grader, you'll received Query and multiple Responses:
248+
template = """You're a LLM query answer ranking grader, you'll receive Query and multiple Responses:
221249
Query: {query}
222250
Responses:
223251
1. {response_1}
@@ -308,9 +336,8 @@ def test_serialization_methods(self):
308336
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
309337
RUN_QUALITY_TESTS = bool(OPENAI_API_KEY and OPENAI_BASE_URL)
310338

311-
pytestmark = pytest.mark.skipif(not RUN_QUALITY_TESTS, reason="Requires API keys and base URL to run quality tests")
312-
313339

340+
@pytest.mark.skipif(not RUN_QUALITY_TESTS, reason="Requires API keys and base URL to run quality tests")
314341
@pytest.mark.quality
315342
class TestLLMGraderQuality:
316343
"""Quality tests for LLMGrader - testing evaluation quality using golden dataset"""
@@ -361,7 +388,7 @@ def model(self):
361388
async def test_discriminative_power_with_runner(self, dataset, model):
362389
"""Test the grader's ability to distinguish between accurate and inaccurate responses (using Runner)"""
363390
# Create grader with real model following the specification in docs
364-
template = """You're a LLM query answer accuracy grader, you'll received Query/Response and Context:
391+
template = """You're a LLM query answer accuracy grader, you'll receive Query/Response and Context:
365392
Query: {query}
366393
Response: {response}
367394
Context: {context}
@@ -420,7 +447,7 @@ async def test_discriminative_power_with_runner(self, dataset, model):
420447
async def test_consistency_with_runner(self, dataset, model):
421448
"""Test grader evaluation consistency (using Runner)"""
422449
# Create grader with real model following the specification in docs
423-
template = """You're a LLM query answer accuracy grader, you'll received Query/Response and Context:
450+
template = """You're a LLM query answer accuracy grader, you'll receive Query/Response and Context:
424451
Query: {query}
425452
Response: {response}
426453
Context: {context}

0 commit comments

Comments
 (0)