|
11 | 11 | from dotenv import load_dotenv |
12 | 12 | from deepeval.test_case import LLMTestCase |
13 | 13 | from deepeval import evaluate |
| 14 | +from deepeval.models import GeminiModel |
14 | 15 |
|
15 | | -from auto_evaluation.src.models.vertex_ai import GoogleVertexAILangChain |
16 | 16 | from auto_evaluation.src.metrics.retrieval import ( |
17 | 17 | make_contextual_precision_metric, |
18 | 18 | make_contextual_recall_metric, |
@@ -42,7 +42,11 @@ def __init__(self, base_url: str, dataset: str, reranker_base_url: str = ""): |
42 | 42 | self.dataset = dataset |
43 | 43 | self.reranker_base_url = reranker_base_url |
44 | 44 | self.qns = preprocess.read_data(self.dataset) |
45 | | - self.eval_model = GoogleVertexAILangChain(model_name="gemini-1.5-pro-002") |
| 45 | + self.eval_model = GeminiModel( |
| 46 | + model_name="gemini-1.5-pro-002", |
| 47 | + project=os.getenv("GOOGLE_PROJECT_ID", ""), |
| 48 | + location=os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"), |
| 49 | + ) |
46 | 50 | self.log_dir = "logs" |
47 | 51 | os.makedirs(self.log_dir, exist_ok=True) |
48 | 52 | self.sanity_check() |
@@ -91,8 +95,8 @@ def evaluate(self, retriever: str): |
91 | 95 |
|
92 | 96 | # parallel evaluate |
93 | 97 | evaluate( |
94 | | - retrieval_tcs, |
95 | | - [precision, recall, hallucination], |
| 98 | + test_cases=retrieval_tcs, |
| 99 | + metrics=[precision, recall, hallucination], |
96 | 100 | print_results=False, |
97 | 101 | ) |
98 | 102 |
|
|
0 commit comments