# SPDX-License-Identifier: Apache-2.0
# Standard
from pathlib import Path
- from typing import List, Optional, TypedDict
+ from typing import TYPE_CHECKING, List, Optional, TypedDict

# Third Party
from langchain_community.chat_models import ChatOpenAI
from openai import Client as OpenAIClient
+ from openai.types.chat import ChatCompletionMessageParam
from pandas import DataFrame, read_json
- from pydantic import BaseModel, ConfigDict, field_validator
+ from pydantic import BaseModel, ConfigDict, Field
from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
from ragas.metrics import Metric
from ragas.metrics._domain_specific_rubrics import (  # the rubrics we must instantiate are located inside of a file marked as private

# Local
from .evaluator import Evaluator
+ from .logger_config import setup_logger
+
+ logger = setup_logger(__name__)


class Sample(TypedDict):
@@ -56,21 +60,14 @@ class ModelConfig(BaseModel):
    system_prompt: str = _DEFAULT_SYSTEM_PROMPT

    # "model randomness" aka likelihood of sampling something other than the likeliest token
-     temperature: float = 0.0
+     temperature: float = Field(default=0.0, le=1.0, ge=0.0)

    # Max amount of tokens to generate.
    max_tokens: int = 768

    # Random seed for reproducibility. Caution: this isn't supported by all model serving runtimes.
    seed: int = DEFAULT_SEED

-     @field_validator("temperature")
-     @classmethod
-     def check_temperature(cls, v: float) -> float:
-         if not 0.0 <= v <= 1.0:
-             raise ValueError("temperature must be between 0.0 and 1.0")
-         return v
-

class RagasEvaluator(Evaluator):
    # most basic implementation, we just assume that the user will bring the existing model responses
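
For anyone less familiar with pydantic v2 constrained fields: the `Field(ge=0.0, le=1.0)` declaration above enforces the same bounds check that the removed `field_validator` performed by hand, raising a `ValidationError` on out-of-range values at construction time. A minimal sketch, with a made-up demo model standing in for `ModelConfig`:

# Minimal sketch: Field(ge=..., le=...) rejects out-of-range values when the
# model is constructed, just like the removed hand-written validator did.
from pydantic import BaseModel, Field, ValidationError

class TemperatureDemo(BaseModel):  # hypothetical stand-in for ModelConfig
    temperature: float = Field(default=0.0, ge=0.0, le=1.0)

TemperatureDemo(temperature=0.7)  # accepted
try:
    TemperatureDemo(temperature=1.5)  # rejected: exceeds the le=1.0 bound
except ValidationError as err:
    print(err)
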
@@ -80,18 +77,42 @@ def __init__(
        self,
        student_model: ModelConfig | None = None,
        run_config: RunConfig | None = None,
-         openai_client: OpenAIClient | None = None,
+         student_openai_client: OpenAIClient | None = None,
+         judge_model_name: str = DEFAULT_JUDGE_MODEL,
+         judge_openai_api_key: str | None = None,
    ):
        self.student_model = student_model
        self.run_config = run_config
-         self.openai_client = openai_client
+         self.student_openai_client = student_openai_client
+         self.judge_model_name = judge_model_name
+         self.judge_openai_api_key = judge_openai_api_key
+
+     @staticmethod
+     def _validate_dataset(df: DataFrame):
+         """
+         Validates whether or not the given `df` is a valid dataset of `Sample` objects.
+
+         Args:
+             df (DataFrame): DataFrame containing the dataset to be evaluated.
+         """
+         # We have to hardcode these fields because the automated way of resolving the required fields from a TypedDict
+         # is only included by default in Python 3.11+. For earlier versions, the `typing_extensions` package is required.
+         # See: https://docs.python.org/3/whatsnew/3.11.html#pep-655-marking-individual-typeddict-items-as-required-or-not-required
+         required_keys = {"user_input", "reference"}
+         missing_keys = required_keys - set(df.columns)
+         if missing_keys:
+             raise ValueError(
+                 f"invalid dataset provided, missing the following keys: {', '.join(missing_keys)}"
+             )
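
As the comment above notes, `required_keys` is hardcoded because deriving it from `Sample` relies on PEP 655 markers that only reach the standard library in Python 3.11. A rough sketch of that alternative, using the `typing_extensions` backport and a hypothetical stand-in for the module's `Sample` (not part of this change):

# Rough sketch of deriving required keys from a TypedDict instead of
# hardcoding them; SampleSketch is a stand-in, not the module's Sample.
from typing_extensions import NotRequired, TypedDict

class SampleSketch(TypedDict):
    user_input: str
    reference: str
    response: NotRequired[str]  # optional: generated when missing

required_keys = set(SampleSketch.__required_keys__)
assert required_keys == {"user_input", "reference"}
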

    def run(
        self,
        dataset: List[Sample] | Path,
        student_model: ModelConfig | None = None,
        run_config: RunConfig | None = None,
-         openai_client: OpenAIClient | None = None,
+         student_openai_client: OpenAIClient | None = None,
+         judge_model_name: str | None = None,
+         judge_openai_api_key: str | None = None,
    ) -> EvaluationResult:
        """
        Evaluates the quality of model responses against a graded rubric.
@@ -111,21 +132,31 @@ def run(
                a default one is created containing extremely permissive settings when handling
                timeouts. This is because by default, OpenAI tier-1 usage accounts have very low
                rate limits, resulting in heavy throttling during evaluations.
-             openai_client (openai.Client | None, optional):
+             student_openai_client (openai.Client | None, optional):
                The client to use when generating questions from the student model, must be compatible with the OpenAI API.
                This field is required when `student_model` is provided.
+             judge_model_name (str | None, optional):
+                 Name of the OpenAI model to use as the judge model. Defaults to "gpt-4o" when none is specified.
+             judge_openai_api_key (str | None, optional):
+                 The API key to use for evaluating the given dataset. When this isn't provided, `OPENAI_API_KEY` is read instead.
+

        Returns:
            EvaluationResult: The results of all evaluations performed by Ragas
        """
+         judge_model_name = (
+             judge_model_name if judge_model_name else self.judge_model_name
+         )
+         judge_openai_api_key = (
+             judge_openai_api_key if judge_openai_api_key else self.judge_openai_api_key
+         )
        student_model = student_model if student_model else self.student_model
        run_config = run_config if run_config else self.run_config
-         openai_client = openai_client if openai_client else self.openai_client
-
-         if not dataset:
-             raise ValueError(
-                 "no dataset was provided, please specify the `dataset` argument"
-             )
+         student_openai_client = (
+             student_openai_client
+             if student_openai_client
+             else self.student_openai_client
+         )

        # ensure we are in the dataframe format
        input_df = None
@@ -137,22 +168,30 @@ def run(
            raise TypeError(f"invalid type of dataset: {type(dataset)}")

        # this should never happen, but pylint is not smart enough to detect it
-         assert input_df is not None
+         if TYPE_CHECKING:
+             assert input_df is not None
+
+         # ensure the dataset is in the format we expect
+         self._validate_dataset(input_df)

        need_to_generate_questions = "response" not in input_df.columns
-         if need_to_generate_questions and (not student_model or not openai_client):
-             raise ValueError(
-                 "provided dataset doesn't contain the model `response`, but either `student_model` or `openai_client` wasn't provided for inference"
+         if need_to_generate_questions:
+             logger.debug(
+                 "`response` is missing in the input dataframe columns, generating questions from the model is required."
            )
+             if not student_model or not student_openai_client:
+                 raise ValueError(
+                     "provided dataset doesn't contain the model `response`, but either `student_model` or `student_openai_client` wasn't provided for inference"
+                 )

        # if the student model was provided then we always generate regardless
        if student_model:
-             if not openai_client:
+             if not student_openai_client:
                raise ValueError(
-                     "`student_model` was specified but `openai_client` was not provided"
+                     "`student_model` was specified but `student_openai_client` was not provided"
                )
            input_df = self._generate_answers_from_model(
-                 input_df, student_model, openai_client
+                 input_df, student_model, student_openai_client
            )

        if not run_config:
@@ -170,7 +209,8 @@ def run(

        # we will be using gpt-4o for the foreseeable future, we hardcode this
        # for consistency of answers
-         critic_lm = ChatOpenAI(model=DEFAULT_JUDGE_MODEL)
+
+         critic_lm = ChatOpenAI(model=judge_model_name, api_key=judge_openai_api_key)
        results = evaluate(
            dataset=evaluation_ds,
            batch_size=4,
@@ -185,7 +225,7 @@ def _generate_answers_from_model(
        self,
        questions: DataFrame,
        student_model: ModelConfig,
-         openai_client: OpenAIClient,
+         student_openai_client: OpenAIClient,
    ) -> DataFrame:
        """
        Given a DataFrame containing `user_input` columns, generates responses from the given model
@@ -196,11 +236,14 @@ def _generate_answers_from_model(
        updated_df["response"] = ""

        for i, qna in updated_df.iterrows():
-             messages = [
-                 student_model.system_prompt,
-                 qna["user_input"],
+             messages: List[ChatCompletionMessageParam] = [
+                 {
+                     "role": "system",
+                     "content": student_model.system_prompt,
+                 },
+                 {"role": "user", "content": qna["user_input"]},
            ]
-             response = openai_client.chat.completions.create(
+             response = student_openai_client.chat.completions.create(
                messages=messages,
                model=student_model.model_name,
                # specify the seed so we can at least try to have some reproducibility when the clients support it
@@ -211,7 +254,8 @@ def _generate_answers_from_model(
            updated_df.at[i, "response"] = response.choices[0].message.content
        return updated_df
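
For context, the typed `messages` list built above follows the standard OpenAI chat-completions payload shape. A self-contained sketch of an equivalent request, where the endpoint, key, and model name are placeholders rather than values from this change:

# Stand-alone sketch of the request shape used in _generate_answers_from_model;
# base_url, api_key, and the model name are placeholders.
from typing import List

from openai import Client as OpenAIClient
from openai.types.chat import ChatCompletionMessageParam

client = OpenAIClient(base_url="http://localhost:8000/v1", api_key="EMPTY")
messages: List[ChatCompletionMessageParam] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
completion = client.chat.completions.create(
    messages=messages,
    model="student-model",  # placeholder model name
    seed=42,
    max_tokens=768,
    temperature=0.0,
)
print(completion.choices[0].message.content)
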

-     def _get_metrics(self) -> List[Metric]:
+     @staticmethod
+     def _get_metrics() -> List[Metric]:
        # default set of metrics
        return [
            RubricsScore(
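
Putting the new surface together, a hedged usage sketch of the updated evaluator; the import path, endpoint, API key, and model names below are assumptions for illustration, not values taken from this diff:

# Hedged usage sketch; the import path, endpoint, key, and model names are assumed.
from typing import List

from openai import Client as OpenAIClient

from instructlab.eval.ragas import ModelConfig, RagasEvaluator, Sample  # path assumed

student_client = OpenAIClient(base_url="http://localhost:8000/v1", api_key="EMPTY")

dataset: List[Sample] = [
    {
        "user_input": "What is the capital of France?",
        "reference": "The capital of France is Paris.",
    }
]

evaluator = RagasEvaluator(
    student_model=ModelConfig(model_name="student-model"),
    student_openai_client=student_client,
    judge_model_name="gpt-4o",
    judge_openai_api_key="sk-...",  # placeholder; OPENAI_API_KEY is read when omitted
)
result = evaluator.run(dataset=dataset)
print(result)
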