
Commit d58cd0d

format code

1 parent: 67b6eec

File tree

5 files changed: +105 -87 lines


evaluation/scripts/PrefEval/pref_eval.py

Lines changed: 22 additions & 23 deletions
@@ -8,7 +8,7 @@
 import os
 import pandas as pd
 from dotenv import load_dotenv
-from openai import OpenAI
+from openai import OpenAI
 
 load_dotenv()
 
@@ -21,17 +21,16 @@
 
 
 async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str:
-
     messages = [{"role": "user", "content": prompt}]
 
     try:
         response = await asyncio.to_thread(
-        client.chat.completions.create,
-        model="gpt-4o-mini",
+            client.chat.completions.create,
+            model="gpt-4o-mini",
             messages=messages,
             temperature=0,
             max_tokens=500,
-            timeout=30.0
+            timeout=30.0,
         )
         return response.choices[0].message.content
     except Exception as e:
@@ -45,7 +44,7 @@ def parse_xml_response(response: str, tag: str) -> str:
 
 
 async def evaluate_violate_preference_async(
-        client: OpenAI, preference: str, question: str, response: str
+    client: OpenAI, preference: str, question: str, response: str
 ) -> Dict[str, str]:
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's response violates the user's stated preference.
 Evaluate the response based on these criteria:
@@ -69,15 +68,15 @@ async def evaluate_violate_preference_async(
 <explanation>[1 very short sentence explanation]</explanation>
 <answer>[Yes/No]</answer>"""
 
-        api_response = await call_gpt4o_mini_async(client, prompt)
+    api_response = await call_gpt4o_mini_async(client, prompt)
     return {
         "explanation": parse_xml_response(api_response, "explanation"),
         "answer": parse_xml_response(api_response, "answer"),
     }
 
 
 async def evaluate_acknowledge_preference_async(
-        client: OpenAI, question: str, response: str
+    client: OpenAI, question: str, response: str
 ) -> Dict[str, str]:
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query.
@@ -99,15 +98,15 @@ async def evaluate_acknowledge_preference_async(
 Examine the response meticulously and answer. Please answer in this exact XML format without any additional text:
 <preference>[quote of the sentence that acknowledges/mentions what the preference is; leave it blank if there is none]</preference>
 <answer>[Yes/No]</answer>"""
-        api_response = await call_gpt4o_mini_async(client, prompt)
+    api_response = await call_gpt4o_mini_async(client, prompt)
     return {
         "preference_mention": parse_xml_response(api_response, "preference"),
         "answer": parse_xml_response(api_response, "answer"),
     }
 
 
 async def evaluate_hallucinate_preference_async(
-        client: OpenAI, preference: str, restatement: str
+    client: OpenAI, preference: str, restatement: str
 ) -> Dict[str, str]:
     if not restatement.strip():
         return {"explanation": "No restatement provided by assistant", "answer": "No"}
@@ -132,15 +131,15 @@ async def evaluate_hallucinate_preference_async(
 <explanation>[1 short sentence explanation]</explanation>
 <answer>[Yes/No]</answer>"""
 
-        api_response = await call_gpt4o_mini_async(client, prompt)
+    api_response = await call_gpt4o_mini_async(client, prompt)
     return {
         "explanation": parse_xml_response(api_response, "explanation"),
         "answer": parse_xml_response(api_response, "answer"),
     }
 
 
 async def evaluate_helpful_response_async(
-        client: OpenAI, question: str, response: str
+    client: OpenAI, question: str, response: str
 ) -> Dict[str, str]:
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant provides any substantive response to the user's query.
 Evaluate the response based on these stringent criteria:
@@ -172,7 +171,7 @@ async def evaluate_helpful_response_async(
 <explanation>[1 very short sentence explanation]</explanation>
 <answer>[Yes/No]</answer>"""
 
-        api_response = await call_gpt4o_mini_async(client, prompt)
+    api_response = await call_gpt4o_mini_async(client, prompt)
     return {
         "explanation": parse_xml_response(api_response, "explanation"),
         "answer": parse_xml_response(api_response, "answer"),
@@ -197,9 +196,7 @@ def classify_error_type(evaluation_results: Dict[str, Any]) -> str:
     return "Personalized Response"
 
 
-async def process_line(
-    line: str, client: OpenAI, semaphore: asyncio.Semaphore
-) -> Dict[str, Any]:
+async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> Dict[str, Any]:
     async with semaphore:
         data = json.loads(line.strip())
         preference = data["preference"]
@@ -258,7 +255,7 @@ def generate_excel_summary(
     avg_search_time: float,
     avg_context_tokens: float,
     avg_add_time: float,
-        model_name: str = "gpt-4o-mini",
+    model_name: str = "gpt-4o-mini",
 ):
     print(f"Generating Excel summary at {OUTPUT_EXCEL_FILE}...")
 
@@ -280,7 +277,7 @@ def get_pct(key):
         "Personalized Response\n个性化回答": [personalized_pct / 100],
         "context token": [avg_context_tokens],
         "Time添加": [f"{avg_add_time:.2f}s"],
-        "Time搜索": [f"{avg_search_time:.2f}s"]
+        "Time搜索": [f"{avg_search_time:.2f}s"],
     }
 
     df = pd.DataFrame(data)
@@ -355,9 +352,9 @@ async def main(concurrency_limit: int):
         context_tokens = metrics.get("memory_tokens_used")
         add_time = metrics.get("add_memories_duration_seconds")
 
-        all_metrics_valid = (search_time is not None and
-                             add_time is not None and
-                             context_tokens is not None)
+        all_metrics_valid = (
+            search_time is not None and add_time is not None and context_tokens is not None
+        )
 
         if all_metrics_valid:
             total_search_time += float(search_time)
@@ -375,7 +372,9 @@ async def main(concurrency_limit: int):
 
     avg_search_time = (total_search_time / valid_metric_samples) if valid_metric_samples > 0 else 0
     avg_add_time = (total_add_time / valid_metric_samples) if valid_metric_samples > 0 else 0
-    avg_context_tokens = (total_context_tokens / valid_metric_samples) if valid_metric_samples > 0 else 0
+    avg_context_tokens = (
+        (total_context_tokens / valid_metric_samples) if valid_metric_samples > 0 else 0
+    )
 
     try:
         generate_excel_summary(
@@ -398,4 +397,4 @@ async def main(concurrency_limit: int):
     )
     args = parser.parse_args()
 
-    asyncio.run(main(concurrency_limit=args.concurrency_limit))
+    asyncio.run(main(concurrency_limit=args.concurrency_limit))
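
Aside from the formatting, the structure of pref_eval.py is unchanged: each evaluator builds a prompt and awaits call_gpt4o_mini_async, which pushes the blocking OpenAI call onto a worker thread, while process_line bounds the fan-out with a semaphore. A minimal runnable sketch of that pattern follows; judge_one, run_all, and the limit of 8 are illustrative names and values, not from this file:

# Sketch of the concurrency pattern used in pref_eval.py: a synchronous
# OpenAI client call is moved to a worker thread with asyncio.to_thread,
# and an asyncio.Semaphore caps how many requests are in flight at once.
import asyncio

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment


async def judge_one(prompt: str, semaphore: asyncio.Semaphore) -> str:
    async with semaphore:  # waits here once the limit is reached
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=500,
            timeout=30.0,
        )
        return response.choices[0].message.content


async def run_all(prompts: list[str], limit: int = 8) -> list[str]:
    semaphore = asyncio.Semaphore(limit)
    return await asyncio.gather(*(judge_one(p, semaphore) for p in prompts))

asyncio.to_thread is the simplest way to reuse the synchronous OpenAI client from async code; switching to the SDK's AsyncOpenAI client would avoid the thread hop entirely.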

evaluation/scripts/PrefEval/pref_memos.py

Lines changed: 69 additions & 57 deletions
@@ -1,15 +1,14 @@
+import argparse
+import concurrent.futures
+import json
 import os
 import sys
-import json
 import time
-import uuid
 import tiktoken
-import requests
 from dotenv import load_dotenv
 from openai import OpenAI
 from tqdm import tqdm
-import concurrent.futures
-import argparse
+
 from irrelevant_conv import irre_10, irre_300
 
 ROOT_DIR = os.path.dirname(
@@ -19,15 +18,10 @@
 
 sys.path.insert(0, ROOT_DIR)
 sys.path.insert(0, EVAL_SCRIPTS_DIR)
-
-from utils.client import memos_api_client
-
 load_dotenv()
-
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 BASE_URL = os.getenv("OPENAI_BASE_URL")
 MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
-
 tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
@@ -60,11 +54,9 @@ def add_memory_for_line(
         mem_client.add(messages=conversation, user_id=user_id, conv_id=None)
         end_time_add = time.monotonic()
         add_duration = end_time_add - start_time_add
-
+
         original_data["user_id"] = user_id
-        original_data["metrics"] = {
-            "add_memories_duration_seconds": add_duration
-        }
+        original_data["metrics"] = {"add_memories_duration_seconds": add_duration}
         return original_data
 
     except Exception as e:
@@ -98,7 +90,9 @@ def process_line_with_id(
     start_time_search = time.monotonic()
     relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value)
     search_memories_duration = time.monotonic() - start_time_search
-    memories_str = "\n".join(f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]['memories'])
+    memories_str = "\n".join(
+        f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"]
+    )
 
     memory_tokens_used = len(tokenizer.encode(memories_str))
 
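The memory_tokens_used metric above is a plain token count of the retrieved-memories string under the cl100k_base encoding initialized at the top of the file. A self-contained sketch with stand-in data:

# Count how many tokens a retrieved-memories block consumes, using the
# same cl100k_base encoding the script sets up at import time.
import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")

memories = ["User prefers window seats.", "User is vegetarian."]  # stand-in data
memories_str = "\n".join(f"- {m}" for m in memories)
memory_tokens_used = len(tokenizer.encode(memories_str))
print(memory_tokens_used)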
@@ -111,12 +105,14 @@
     response = openai_client.chat.completions.create(model=MODEL_NAME, messages=messages)
     assistant_response = response.choices[0].message.content
     original_data["response"] = assistant_response
-
-    metrics_dict.update({
-        "search_memories_duration_seconds": search_memories_duration,
-        "memory_tokens_used": memory_tokens_used,
-        "retrieved_memories_text": memories_str
-    })
+
+    metrics_dict.update(
+        {
+            "search_memories_duration_seconds": search_memories_duration,
+            "memory_tokens_used": memory_tokens_used,
+            "retrieved_memories_text": memories_str,
+        }
+    )
     original_data["metrics"] = metrics_dict
 
     return original_data
@@ -169,56 +165,72 @@ def main():
         print(f"Error: Input file '{args.input}' not found")
         return
 
-    mem_client = memos_api_client()
+    from utils.client import memosApiClient
+
+    mem_client = memosApiClient()
 
     if args.mode == "add":
         print(f"Running in 'add' mode. Ingesting memories from '{args.input}'...")
         print(f"Adding {args.add_turn} irrelevant turns.")
         print(f"Using {args.max_workers} workers.")
-        with open(args.output, "w", encoding="utf-8") as outfile, \
-             concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-            futures = [
-                executor.submit(
-                    add_memory_for_line,
-                    (i, line),
-                    mem_client,
-                    args.add_turn,
-                    args.lib,
-                    args.version,
-                )
-                for i, line in enumerate(lines)
-            ]
-
-            pbar = tqdm(
-                concurrent.futures.as_completed(futures),
-                total=len(lines),
-                desc="Adding memories...",
+        with (
+            open(args.output, "w", encoding="utf-8") as outfile,
+            concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
+        ):
+            futures = [
+                executor.submit(
+                    add_memory_for_line,
+                    (i, line),
+                    mem_client,
+                    args.add_turn,
+                    args.lib,
+                    args.version,
                 )
-            for future in pbar:
-                result = future.result()
-                if result:
-                    outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
+                for i, line in enumerate(lines)
+            ]
+
+            pbar = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(lines),
+                desc="Adding memories...",
+            )
+            for future in pbar:
+                result = future.result()
+                if result:
+                    outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
         print(f"\n'add' mode complete! Data with user_id written to '{args.output}'.")
 
     elif args.mode == "process":
         print(f"Running in 'process' mode. Processing questions from '{args.input}'...")
         print(f"Retrieving top {args.top_k} memories for each query.")
         print(f"Using {args.max_workers} workers.")
         openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL)
-        with open(args.output, "w", encoding="utf-8") as outfile, \
-             concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-
-            futures = [executor.submit(process_line_with_id, (i, line), mem_client, openai_client, args.top_k, args.lib, args.version) for i, line in enumerate(lines)]
-
-            pbar = tqdm(
-                concurrent.futures.as_completed(futures),
-                total=len(lines),
-                desc="Processing questions...",
+        with (
+            open(args.output, "w", encoding="utf-8") as outfile,
+            concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
+        ):
+            futures = [
+                executor.submit(
+                    process_line_with_id,
+                    (i, line),
+                    mem_client,
+                    openai_client,
+                    args.top_k,
+                    args.lib,
+                    args.version,
                 )
-            for future in pbar:
-                result = future.result()
-                if result:
-                    outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
+                for i, line in enumerate(lines)
+            ]
+
+            pbar = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(lines),
+                desc="Processing questions...",
+            )
+            for future in pbar:
+                result = future.result()
+                if result:
+                    outfile.write(json.dumps(result, ensure_ascii=False) + "\n")
         print(f"\n'process' mode complete! Final results written to '{args.output}'.")
 
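The main change in both branches of main() is replacing the backslash-continued with statement by parenthesized context managers, a form Python supports natively from 3.10 onward. A runnable sketch of the construct, with a placeholder file name and workload:

# Parenthesized context managers (Python 3.10+): both resources are
# acquired left to right and released in reverse order on exit.
import concurrent.futures
import json

with (
    open("out.jsonl", "w", encoding="utf-8") as outfile,  # placeholder path
    concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor,
):
    futures = [executor.submit(pow, 2, n) for n in range(8)]
    for future in concurrent.futures.as_completed(futures):
        outfile.write(json.dumps({"result": future.result()}) + "\n")

On interpreters older than 3.10, the same grouping requires backslash continuations, nested with blocks, or contextlib.ExitStack.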
evaluation/scripts/PrefEval/prefeval_preprocess.py

Lines changed: 0 additions & 1 deletion
@@ -92,7 +92,6 @@ def process_jsonl_file(input_filepath, output_filepath):
 
 def main():
     huggingface_dataset_name = "siyanzhao/prefeval_implicit_persona"
-    # output_directory = "./PrefEval"
     output_directory = "./data/prefeval"
     input_file_path = os.path.join(output_directory, "train.jsonl")
     processed_file_path = os.path.join(output_directory, "pref_processed.jsonl")
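
For orientation, the train.jsonl this script consumes is a dump of the Hugging Face dataset named above. A hedged sketch of producing it with the standard datasets API; the real preprocessing script may select or rename fields differently:

# Download the PrefEval implicit-persona split and write it as JSONL
# under the same ./data/prefeval directory the script uses.
import os

from datasets import load_dataset

output_directory = "./data/prefeval"
os.makedirs(output_directory, exist_ok=True)

dataset = load_dataset("siyanzhao/prefeval_implicit_persona", split="train")
dataset.to_json(os.path.join(output_directory, "train.jsonl"), lines=True, force_ascii=False)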
