
Commit 31b9a8b

Besmira Nushi (nushib) authored
Benushi/offline model (#157)
- Added an OfflineFileModel class in models.py which reads results from a precomputed file instead of calling an actual model. This can be used when results are computed beforehand outside of Eureka, or produced by a multi-agent system.
- Added a model config, OFFLINE_MODEL_CONFIG, in model_configs.py to show how to use this.

Co-authored-by: Besmira Nushi <[email protected]>
1 parent 86b57ab commit 31b9a8b
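
To make the workflow concrete, here is a minimal sketch of how the new class could be driven end to end. The file name, prompts, and outputs are hypothetical; the import path follows the __init__.py change below, and the sketch assumes the Model base class needs no extra constructor arguments and supplies the count_tokens helper that generate() calls.

import json

from eureka_ml_insights.models import OfflineFileModel

# Write a small offline results file in the expected JSON Lines format
# (hypothetical contents).
records = [
    {"prompt": "What is 2 + 2?", "model_output": "4", "data_repeat_id": 0},
    {"prompt": "What is 2 + 2?", "model_output": "Four", "data_repeat_id": 1},
]
with open("your_offline_model_results.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# No endpoint is called: generate() looks the answer up in the preloaded file.
model = OfflineFileModel(file_path="your_offline_model_results.jsonl", model_name="Teacher_Agent_V1")
response = model.generate("What is 2 + 2?", data_repeat_id=0)
print(response["model_output"])  # -> "4"
print(response["is_valid"])      # -> True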

File tree

4 files changed, +124 -6 lines changed

eureka_ml_insights/configs/model_configs.py

Lines changed: 13 additions & 2 deletions
@@ -19,6 +19,7 @@
     RestEndpointModel,
     TogetherModel,
     TestModel,
+    OfflineFileModel
 )
 from eureka_ml_insights.models.models import AzureOpenAIModel

@@ -32,6 +33,17 @@
 # Test model
 TEST_MODEL_CONFIG = ModelConfig(TestModel, {})

+OFFLINE_MODEL_CONFIG = ModelConfig(
+    OfflineFileModel,
+    {
+        "model_name": "Teacher_Agent_V1",
+        # This file contains the offline results from a model or agentic system.
+        # The file should contain at least the following fields:
+        # "model_output", "prompt", and "data_repeat_id" for experiments that have several runs/repeats.
+        "file_path": r"your_offline_model_results.jsonl",
+    },
+)
+
 # Together models
 TOGETHER_SECRET_KEY_PARAMS = {
     "key_name": "your_togetherai_secret_key_name",
@@ -64,7 +76,6 @@
 )

 # OpenAI models
-
 OPENAI_SECRET_KEY_PARAMS = {
     "key_name": "your_openai_secret_key_name",
     "local_keys_path": "keys/keys.json",
@@ -104,7 +115,7 @@
     },
 )

-OAI_O1_PREVIEW_AUZRE_CONFIG = ModelConfig(
+OAI_O1_PREVIEW_AZURE_CONFIG = ModelConfig(
     AzureOpenAIOModel,
     {
         "model_name": "o1-preview",

eureka_ml_insights/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     Phi4HFModel,
     RestEndpointModel,
     TestModel,
+    OfflineFileModel,
     VLLMModel,
     TogetherModel
 )
@@ -44,6 +45,7 @@
     LocalVLLMModel,
     RestEndpointModel,
     TestModel,
+    OfflineFileModel,
     VLLMModel,
     TogetherModel
 ]

eureka_ml_insights/models/models.py

Lines changed: 104 additions & 1 deletion
@@ -1,6 +1,7 @@
 """This module contains classes for interacting with various models, including API-based models and HuggingFace models."""

 import json
+import pandas as pd
 import logging
 import random
 import threading
@@ -87,7 +88,6 @@ def get_api_key(self):
         self.api_key = get_secret(**self.secret_key_params)
         return self.api_key

-
 @dataclass
 class EndpointModel(Model):
     """This class is used to interact with API-based models."""
@@ -180,6 +180,109 @@ def generate(self, query_text, *args, **kwargs):
     def handle_request_error(self, e):
         raise NotImplementedError

+@dataclass
+class OfflineFileModel(Model):
+    """This class is used to read pre-generated model/system results from a local file."""
+
+    file_path: str = None
+    model_name: str = None
+    df_results: pd.DataFrame = None
+
+    def __post_init__(self):
+        if not self.file_path:
+            raise ValueError("file_path must be provided.")
+        if not self.model_name:
+            raise ValueError(
+                "Model name must be provided as additional information on the model/system that was previously used to generate the file in file_path."
+            )
+
+        # Load the results from the file into a DataFrame that is reused for reading all individual results later.
+        try:
+            self.df_results = pd.read_json(self.file_path, lines=True)
+        except FileNotFoundError:
+            raise FileNotFoundError(f"Error: File '{self.file_path}' not found.")
+        except ValueError as ve:
+            raise ValueError(f"Error reading JSON from '{self.file_path}': {ve}")
+        except Exception as e:
+            logging.error(f"An unexpected error occurred: {e}")
+            raise
+
+        # Check for required columns in the file.
+        required_columns = {"prompt", "model_output"}
+        missing_columns = required_columns - set(self.df_results.columns)
+        if missing_columns:
+            raise ValueError(f"Error: Missing required columns in file_path: {missing_columns}")
+
+    def generate(self, query_text, *args, **kwargs):
+        """
+        Reads the file at file_path to retrieve the precomputed model response.
+        args:
+            query_text (str): the text prompt for which to retrieve the response.
+            data_repeat_id (str): the id of the repeat for the same prompt, if the file contains multiple repeats per prompt.
+        returns:
+            response_dict (dict): a dictionary containing the model_output, is_valid, response_time, and n_output_tokens,
+            and any other relevant information returned by the model.
+        """
+        response_dict = {}
+        if hasattr(self, "system_message") and self.system_message:
+            if "system_message" in kwargs:
+                logging.warning(
+                    "Warning: A system message is passed via the dataloader but will not be used because the inference results are precomputed offline in file_path."
+                )
+            kwargs["system_message"] = self.system_message
+
+        if hasattr(self, "query_images") and self.query_images:
+            if "query_images" in kwargs:
+                logging.warning("Warning: Images are not yet supported for this model class.")
+            kwargs["query_images"] = self.query_images
+
+        if hasattr(self, "chat_mode") and self.chat_mode:
+            if "chat_mode" in kwargs:
+                logging.warning("Warning: Chat mode is not supported for this model class.")
+
+        model_output = None
+        is_valid = False
+        response_time = 0  # Dummy value; the response time is not available for offline files.
+        n_output_tokens = None
+
+        try:
+            model_response = self.get_response(query_text, kwargs.get("data_repeat_id", None))
+            model_output = model_response["model_output"]
+            is_valid = model_response["is_valid"]
+        except Exception as e:
+            logging.warning(f"Warning: Failed to retrieve the offline response: {e}")
+
+        response_dict.update(
+            {
+                "is_valid": is_valid,
+                "model_output": model_output,
+                "response_time": response_time,
+                "n_output_tokens": n_output_tokens or self.count_tokens(model_output, is_valid),
+            }
+        )
+        return response_dict
+
+    def get_response(self, target_prompt, target_repeat_id):
+        if target_repeat_id is None:
+            filtered_df = self.df_results[self.df_results["prompt"] == target_prompt]
+        else:
+            filtered_df = self.df_results[
+                (self.df_results["data_repeat_id"] == target_repeat_id) & (self.df_results["prompt"] == target_prompt)
+            ]
+
+        # Check if a matching record exists.
+        if not filtered_df.empty:
+            if len(filtered_df) > 1:
+                logging.warning(
+                    f"Warning: More than one matching record found ({len(filtered_df)} records). Returning the first one."
+                )
+            model_output = str(filtered_df.iloc[0]["model_output"])
+            # If the model output is empty, return None and is_valid as False.
+            if len(model_output) == 0:
+                return {"model_output": None, "is_valid": False}
+            return {"model_output": filtered_df.iloc[0]["model_output"], "is_valid": True}
+        else:
+            return {"model_output": None, "is_valid": False}

 @dataclass
 class RestEndpointModel(EndpointModel, KeyBasedAuthMixIn):
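
As a usage note, get_response filters the preloaded DataFrame by prompt, and additionally by data_repeat_id when generate() receives one through its kwargs. A short sketch of the resulting behavior, reusing the hypothetical file from the sketch near the top of this page:

from eureka_ml_insights.models import OfflineFileModel

model = OfflineFileModel(file_path="your_offline_model_results.jsonl", model_name="Teacher_Agent_V1")

# With a repeat id, both prompt and data_repeat_id must match.
print(model.get_response("What is 2 + 2?", 1))
# -> {'model_output': 'Four', 'is_valid': True}

# Without a repeat id, only the prompt is matched; if several records match,
# a warning is logged and the first one is returned.
print(model.get_response("What is 2 + 2?", None))

# Unknown prompts are reported as invalid rather than raising.
print(model.get_response("Unseen prompt", None))
# -> {'model_output': None, 'is_valid': False}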

eureka_ml_insights/user_configs/aime.py

Lines changed: 5 additions & 3 deletions
@@ -33,7 +33,7 @@
     SequenceTransform,
 )
 from eureka_ml_insights.data_utils.aime_utils import AIMEExtractAnswer
-from eureka_ml_insights.data_utils.data import DataLoader
+from eureka_ml_insights.data_utils.data import MMDataLoader
 from eureka_ml_insights.metrics.aime_metrics import NumericMatch
 from eureka_ml_insights.metrics.reports import (
     BiLevelAggregator,
@@ -86,8 +86,10 @@ def configure_pipeline(
             component_type=Inference,
             model_config=model_config,
             data_loader_config=DataSetConfig(
-                DataLoader,
-                {"path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")},
+                MMDataLoader,
+                {
+                    "path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl"),
+                },
             ),
             output_dir=os.path.join(self.log_dir, "inference_result"),
             resume_from=resume_from,
