8
8
9
9
from pydantic_ai import Agent , models
10
10
11
- __all__ = ('GradingOutput' , 'judge_input_output' , 'judge_output' )
11
+ __all__ = ('GradingOutput' , 'judge_input_output' , 'judge_output' , 'set_default_judge_model' )
12
+
13
+
14
+ _default_model : models .Model | models .KnownModelName = 'openai:gpt-4o'
12
15
13
16
14
17
class GradingOutput (BaseModel , populate_by_name = True ):
@@ -41,11 +44,15 @@ class GradingOutput(BaseModel, populate_by_name=True):
41
44
42
45
43
46
async def judge_output (
44
- output : Any , rubric : str , model : models .Model | models .KnownModelName = 'openai:gpt-4o'
47
+ output : Any , rubric : str , model : models .Model | models .KnownModelName | None = None
45
48
) -> GradingOutput :
46
- """Judge the output of a model based on a rubric."""
49
+ """Judge the output of a model based on a rubric.
50
+
51
+ If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
52
+ but this can be changed using the `set_default_judge_model` function.
53
+ """
47
54
user_prompt = f'<Output>\n { _stringify (output )} \n </Output>\n <Rubric>\n { rubric } \n </Rubric>'
48
- return (await _judge_output_agent .run (user_prompt , model = model )).data
55
+ return (await _judge_output_agent .run (user_prompt , model = model or _default_model )).data
49
56
50
57
51
58
_judge_input_output_agent = Agent (
@@ -72,11 +79,24 @@ async def judge_output(
72
79
73
80
74
81
async def judge_input_output (
75
- inputs : Any , output : Any , rubric : str , model : models .Model | models .KnownModelName = 'openai:gpt-4o'
82
+ inputs : Any , output : Any , rubric : str , model : models .Model | models .KnownModelName | None = None
76
83
) -> GradingOutput :
77
- """Judge the output of a model based on the inputs and a rubric."""
84
+ """Judge the output of a model based on the inputs and a rubric.
85
+
86
+ If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
87
+ but this can be changed using the `set_default_judge_model` function.
88
+ """
78
89
user_prompt = f'<Input>\n { _stringify (inputs )} \n </Input>\n <Output>\n { _stringify (output )} \n </Output>\n <Rubric>\n { rubric } \n </Rubric>'
79
- return (await _judge_input_output_agent .run (user_prompt , model = model )).data
90
+ return (await _judge_input_output_agent .run (user_prompt , model = model or _default_model )).data
91
+
92
+
93
+ def set_default_judge_model (model : models .Model | models .KnownModelName ) -> None : # pragma: no cover
94
+ """Set the default model used for judging.
95
+
96
+ This model is used if `None` is passed to the `model` argument of `judge_output` and `judge_input_output`.
97
+ """
98
+ global _default_model
99
+ _default_model = model
80
100
81
101
82
102
def _stringify (value : Any ) -> str :
0 commit comments