 # Standard
 from pathlib import Path
-from typing import List, TypedDict
+from typing import List, Optional, TypedDict

 # Third Party
 from langchain_community.chat_models import ChatOpenAI
+from pandas import DataFrame, read_json
+from pydantic import BaseModel, ConfigDict, field_validator
 from ragas.evaluation import EvaluationDataset, EvaluationResult, RunConfig, evaluate
+from ragas.metrics import Metric
 from ragas.metrics._domain_specific_rubrics import (  # the rubrics we must instantiate are located inside of a file marked as private
     DEFAULT_WITH_REFERENCE_RUBRICS,
     RubricsScore,
 )
-import pandas as pd

 # Local
 from .evaluator import Evaluator
+from .mt_bench_common import get_openai_client


 class Sample(TypedDict):
22+ """
23+ TypedDict of a sample that we accept when doing eval with Ragas.
24+ We specifically use TypedDict here to be flexible with the input data we accept.
25+ """
26+
     # question
     user_input: str

     # model answer
-    response: str
+    response: Optional[str]

     # golden answer
     reference: str


+# Default system prompt used when none is provided. Kept private since we don't intend it to be a public object.
+_DEFAULT_SYSTEM_PROMPT = """You are an advanced AI assistant designed to provide precise and accurate information.
+Your primary goal is to answer queries with the most up-to-date and factual information available.
+Focus on delivering clear, concise, and correct responses.
+If you're uncertain about any aspect of the query, state your level of confidence and provide the most accurate information you can.
+Your responses should prioritize accuracy over all other considerations."""
+
+DEFAULT_SEED = 1337
+DEFAULT_JUDGE_MODEL = "gpt-4o"
+
+
+class ModelConfig(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
+    # URL of the OpenAI-compatible server where the model is hosted.
+    base_url: str
+
+    # Name of the model to use.
+    model_name: str
+
+    # The system prompt to use when applying the chat template.
+    system_prompt: str = _DEFAULT_SYSTEM_PROMPT
+
+    # We do NOT read OPENAI_API_KEY for the student model, for security reasons (e.g. to avoid sending that key to a different server).
+    # To provide an API key, set it here; otherwise the placeholder default is used.
+    api_key: str = "no-api-key"
+
+    # "Model randomness", i.e. the likelihood of sampling something other than the likeliest token.
+    temperature: float = 0.0
+
+    # Maximum number of tokens to generate.
+    max_tokens: int = 768
+
+    # Random seed for reproducibility. Caution: this isn't supported by all model-serving runtimes.
+    seed: int = DEFAULT_SEED
+
+    @field_validator("temperature")
+    @classmethod
+    def check_temperature(cls, v: float) -> float:
+        if not 0.0 <= v <= 1.0:
+            raise ValueError("temperature must be between 0.0 and 1.0")
+        return v
+
+
 class RagasEvaluator(Evaluator):
     # most basic implementation, we just assume that the user will bring the existing model responses
     name = "ragas"

-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        student_model: ModelConfig | None = None,
+        run_config: RunConfig | None = None,
+    ):
+        self.student_model = student_model
+        self.run_config = run_config

     def run(
-        self, dataset: List[Sample] | Path = None, run_config: RunConfig | None = None
+        self,
+        dataset: List[Sample] | Path,
+        student_model: ModelConfig | None = None,
+        run_config: RunConfig | None = None,
     ) -> EvaluationResult:
         """
         Evaluates the quality of model responses against a graded rubric.

+        When the `dataset` lacks the `response` field, `student_model` must be provided
+        in order to generate the answers.
+
         Args:
             dataset (List[Sample] | Path):
-                List of model questions and answers
+                Can be either a list of `Sample` objects or a path to a jsonl file containing
+                records matching `Sample`.
+            student_model (ModelConfig | None, optional):
+                When provided, we'll attempt to use the described model in order to
+                generate the responses for the given list of questions.
             run_config (RunConfig | None, optional):
                 Configuration to use when running evaluations. If none is provided, then
                 a default one is created containing extremely permissive settings when handling
@@ -51,45 +118,98 @@ def run(
         Returns:
             EvaluationResult: The results of all evaluations performed by Ragas
         """
+        student_model = student_model if student_model else self.student_model
+        run_config = run_config if run_config else self.run_config
+
         if not dataset:
             raise ValueError(
                 "no dataset was provided, please specify the `dataset` argument"
             )
-        if isinstance(dataset, Path):
-            input_ds = EvaluationDataset.from_pandas(
-                pd.read_json(dataset, lines=True, orient="records")
+
+        if not isinstance(dataset, (list, Path)):
+            raise TypeError(f"invalid type of dataset: {type(dataset)}")
+
+        # ensure we are in the dataframe format
+        input_df = None
+        if isinstance(dataset, list):
+            input_df = DataFrame(dataset)
+        elif isinstance(dataset, Path):
+            input_df = read_json(dataset, orient="records", lines=True)
+
+        # this should never happen, but pylint is not smart enough to detect it
+        assert input_df is not None
+
+        need_to_generate_responses = "response" not in input_df.columns
+        if need_to_generate_responses and not student_model:
+            raise ValueError(
+                "provided dataset doesn't contain the model `response`, but no `student_model` was provided for inference"
             )
-        elif isinstance(dataset, list):
-            input_ds = EvaluationDataset.from_list(dataset)
-        else:
-            raise TypeError(f"invalid type passed for dataset: {type(dataset)}")
+
+        # if a student model was provided then we always generate, regardless of existing responses
+        if student_model:
+            input_df = self._generate_answers_from_model(input_df, student_model)

         if not run_config:
             # we set extreme timeout/retry values by default since OpenAI tier-1 rate limits
             # are horrible and will result in half of our evaluation results being NaN or 0
             run_config = RunConfig(
                 max_retries=120,
                 max_wait=7200,
-                seed=42,
+                seed=DEFAULT_SEED,
                 timeout=3600,
             )

-        # default set of metrics
-        metrics = [
-            RubricsScore(
-                rubrics=DEFAULT_WITH_REFERENCE_RUBRICS,
-            )
-        ]
+        metrics = self._get_metrics()
+        evaluation_ds = EvaluationDataset.from_pandas(input_df)

         # we will be using gpt-4o for the foreseeable future, we hardcode this
         # for consistency of answers
-        critic_lm = ChatOpenAI(model="gpt-4o")
+        critic_lm = ChatOpenAI(model=DEFAULT_JUDGE_MODEL)
         results = evaluate(
-            dataset=input_ds,
+            dataset=evaluation_ds,
             batch_size=4,
             run_config=run_config,
             llm=critic_lm,
             metrics=metrics,
             show_progress=True,
         )
         return results
+
+    def _generate_answers_from_model(
+        self, questions: DataFrame, student_model: ModelConfig
+    ) -> DataFrame:
181+ """
182+ Given a DataFrame containing `user_input` columns, generates responses from the given model
183+ and returns a new DataFrame containing its answers in the `response` column.
184+ """
+        client = get_openai_client(
+            model_api_base=student_model.base_url, api_key=student_model.api_key
+        )
+
+        # initialize the response column we will write into
+        updated_df = questions.copy()
+        updated_df["response"] = ""
+
+        for i, qna in updated_df.iterrows():
+            messages = [
+                {"role": "system", "content": student_model.system_prompt},
+                {"role": "user", "content": qna["user_input"]},
+            ]
+            response = client.chat.completions.create(
+                messages=messages,
+                model=student_model.model_name,
+                # specify the seed so we can at least try to have some reproducibility when the serving runtime supports it
+                seed=student_model.seed,
+                max_tokens=student_model.max_tokens,
+                temperature=student_model.temperature,
+            )
+            updated_df.at[i, "response"] = response.choices[0].message.content
+        return updated_df
+
+    def _get_metrics(self) -> List[Metric]:
+        # default set of metrics
+        return [
+            RubricsScore(
+                rubrics=DEFAULT_WITH_REFERENCE_RUBRICS,
+            )
+        ]
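
For illustration, a minimal usage sketch of the interface added in this diff, using an in-memory dataset that already carries model responses (so no student model is needed). The import path is an assumption since the diff only shows relative imports, and an OpenAI API key is assumed to be available for the gpt-4o judge used internally:

# Hypothetical usage sketch; the import path below is an assumption, not shown in the diff.
from instructlab.eval.ragas import RagasEvaluator, Sample

# Dataset that already contains the model responses, so no student model is required.
samples: list[Sample] = [
    {
        "user_input": "What is the capital of Canada?",
        "response": "The capital of Canada is Ottawa.",
        "reference": "Ottawa is the capital city of Canada.",
    },
]

# OPENAI_API_KEY must be set for the gpt-4o judge used internally by RagasEvaluator.
evaluator = RagasEvaluator()
result = evaluator.run(dataset=samples)
print(result.to_pandas())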
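
And a sketch of the inference path through run(): a jsonl dataset without a `response` column plus a student model described by ModelConfig, which generates the answers before scoring. The endpoint URL, model name, and file name below are placeholders, not values from the PR:

# Hypothetical sketch of the generation path; URL, model name, and file name are placeholders.
from pathlib import Path

from instructlab.eval.ragas import ModelConfig, RagasEvaluator

# questions.jsonl is expected to hold one record per line, e.g.:
# {"user_input": "What is the capital of Canada?", "reference": "Ottawa is the capital of Canada."}
student = ModelConfig(
    base_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
    model_name="my-student-model",        # placeholder model name
)
evaluator = RagasEvaluator(student_model=student)
result = evaluator.run(dataset=Path("questions.jsonl"))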