Skip to content

Commit 8eb7f14

Browse files
saikolasaniSai Kolasani
andauthored
[Ragulate] Save Results + Update Model Providers (#579)
* wip * fix formatting --------- Co-authored-by: Sai Kolasani <[email protected]>
1 parent 94cd787 commit 8eb7f14

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,7 @@
66

77
from tqdm import tqdm
88
from trulens_eval import Tru, TruChain
9-
from trulens_eval.feedback.provider import (
10-
AzureOpenAI,
11-
Bedrock,
12-
Huggingface,
13-
Langchain,
14-
LiteLLM,
15-
OpenAI,
16-
)
17-
from trulens_eval.feedback.provider.base import LLMProvider
9+
from trulens_eval.feedback.provider import AzureOpenAI, Huggingface, LLMProvider, OpenAI
1810
from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus
1911
from typing_extensions import override
2012

@@ -131,6 +123,16 @@ def start_evaluation(self):
131123
self._tru.start_evaluator(disable_tqdm=True)
132124
self._evaluation_running = True
133125

126+
def export_results(self):
127+
"""Export results."""
128+
for dataset_name in self._queries:
129+
records, _feedback_names = self._tru.get_records_and_feedback(
130+
app_ids=[dataset_name]
131+
)
132+
133+
# Export to JSON
134+
records.to_json(f"{self._name}_{dataset_name}_results.json")
135+
134136
def stop_evaluation(self, loc: str):
135137
"""Stop evaluation."""
136138
if self._evaluation_running:
@@ -143,6 +145,7 @@ def stop_evaluation(self, loc: str):
143145
logger.exception("issue stopping evaluator")
144146
finally:
145147
self._progress.close()
148+
self.export_results()
146149

147150
def update_progress(self, query_change: int = 0):
148151
"""Update progress bar."""
@@ -176,12 +179,6 @@ def get_provider(self) -> LLMProvider:
176179
return OpenAI(model_engine=model_name)
177180
if llm_provider == "azureopenai":
178181
return AzureOpenAI(deployment_name=model_name)
179-
if llm_provider == "bedrock":
180-
return Bedrock(model_id=model_name)
181-
if llm_provider == "litellm":
182-
return LiteLLM(model_engine=model_name)
183-
if llm_provider == "Langchain":
184-
return Langchain(model_engine=model_name)
185182
if llm_provider == "huggingface":
186183
return Huggingface(name=model_name)
187184
raise ValueError(f"Unsupported provider: {llm_provider}")

0 commit comments

Comments
 (0)