Commit 372c62f

fix: vllm_eval_webshop without score error

1 parent fe45908

File tree

1 file changed: +55 additions, -77 deletions

openmanus_rl/evaluation/vllm_eval_webshop.py

Lines changed: 55 additions & 77 deletions
@@ -3,7 +3,7 @@
 import json
 import concurrent.futures
 import argparse
-import traceback # Import traceback for detailed error logging
+import traceback  # Import traceback for detailed error logging
 
 from transformers import AutoTokenizer, GenerationConfig
 
@@ -16,19 +16,15 @@ def evaluate_single_task(model_path, env_server_base, max_rounds, idx):
     """
     Initializes necessary components (tokenizer, agent, task, evaluator)
    and evaluates a single WebShop task by index.
-    Returns only the conversation list on success, None otherwise.
+    Returns the experience data on success, None otherwise.
     """
-    # print(f"[Task {idx}] Starting evaluation...")
     try:
         # --- Initialize resources within the worker ---
         # Load tokenizer
         try:
             tokenizer = AutoTokenizer.from_pretrained(model_path)
-            # print(f"[Task {idx}] Tokenizer loaded successfully.")
         except Exception as e:
-            # print(f"[Task {idx}] Error loading tokenizer from {model_path}: {e}")
-            # print(f"[Task {idx}] Please ensure the model_path is correct and the model files are accessible.")
-            return None # Cannot proceed without tokenizer
+            return None  # Cannot proceed without tokenizer
 
         # Define generation config
         generation_config = GenerationConfig(
@@ -38,75 +34,60 @@
             if tokenizer.pad_token_id is not None
             else tokenizer.eos_token_id,
         )
-        # print(f"[Task {idx}] Generation config created.")
 
         # Initialize Agent (model is None as per original script)
         agent = Agent(model=None, tokenizer=tokenizer)
-        # print(f"[Task {idx}] Agent initialized.")
 
         # Initialize WebshopTask
         webshop_task = WebshopTask(
             client_args={
                 "env_server_base": env_server_base,
-                "data_len": 200, # Often unused, can be omitted if causing issues
+                "data_len": 200,  # Often unused, can be omitted if causing issues
                 "timeout": 300,
             },
-            n_clients=1, # Evaluate one task index at a time
+            n_clients=1,  # Evaluate one task index at a time
         )
-        # print(f"[Task {idx}] WebshopTask initialized.")
 
         # Initialize Evaluator
         evaluator = Evaluator(agent, [webshop_task])
-        # print(f"[Task {idx}] Evaluator initialized.")
-        # --- End Initialization ---
-
-        # Call evaluator.eval for a single index.
-        # print(f"[Task {idx}] Calling evaluator.eval...")
+
+        # Call evaluator.eval for a single index
         result = evaluator.eval(
             generation_config=generation_config,
             max_rounds=max_rounds,
-            idxs=[idx], # Evaluate only this specific index
+            idxs=[idx],  # Evaluate only this specific index
         )
-        # print(f"[Task {idx}] Evaluator.eval finished.")
 
-        # Extract conversation if successful
+        # Extract experience data if successful
        if result and result.experiences:
             experience = result.experiences[0]
-            conversation = getattr(experience, 'conversation', None)
-            if conversation is not None:
-                # print(f"[Task {idx}] Evaluation successful, returning conversation.")
-                return conversation
-            else:
-                # print(f"[Task {idx}] Evaluation finished, but no conversation found in experience.")
-                return None
+            # Return entire experience object including conversation, reward, and success
+            return {
+                "conversation": getattr(experience, 'conversation', None),
+                "reward": getattr(experience, 'reward', 0.0),
+                "success": 1 if getattr(experience, 'reward', 0.0) == 1 else 0
+            }
         else:
-            # print(f"[Task {idx}] Evaluation finished, but no experiences returned.")
             return None
 
     except Exception as e:
-        # print(f"[Task {idx}] Error during evaluation: {e}")
-        # Print detailed traceback for debugging
         traceback.print_exc()
         return None
-    finally:
-        # Optional: Clean up resources if necessary, though Python's GC might handle it
-        # for thread-local objects when the thread finishes.
-        # print(f"[Task {idx}] Evaluation attempt complete.")
-        pass
 
 
 def main():
     print(f"Current working directory: {os.getcwd()}")
 
     # --- Argument Parsing ---
-    parser = argparse.ArgumentParser(description='Run WebShop evaluation concurrently, initialize evaluator per worker, and save only conversations to JSONL.')
+    parser = argparse.ArgumentParser(description='Run WebShop evaluation concurrently, initialize evaluator per worker, and save results to JSONL.')
     parser.add_argument('--model_name', type=str, default='Qwen3-8B', help='Name of the model being evaluated (e.g., AgentLM-7B)')
-    parser.add_argument('--sector', type=str, default='Train', help='Sector or domain of the evaluation (e.g., WebShop)')
-    parser.add_argument('--num_tasks', type=int, default=100, help='Number of tasks to process (default: 120)')
-    parser.add_argument('--max_workers', type=int, default=20, help='Maximum number of concurrent workers (default: 120)')
+    parser.add_argument('--sector', type=str, default='eval', help='Sector or domain of the evaluation (e.g., WebShop)')
+    parser.add_argument('--num_tasks', type=int, default=100, help='Number of tasks to process (default: 100)')
+    parser.add_argument('--max_workers', type=int, default=20, help='Maximum number of concurrent workers (default: 20)')
     parser.add_argument('--model_path', type=str, default="/data1/models/Qwen/Qwen3-8B-FP8", help='Path to the model directory')
     parser.add_argument('--env_server_base', type=str, default="http://127.0.0.1:36001", help='Base URL for the environment server')
     parser.add_argument('--max_rounds', type=int, default=7, help='Maximum interaction rounds per task (default: 7)')
+    parser.add_argument('--output_file', type=str, default="", help='Output file path (default: {model_name}_{sector}.jsonl)')
 
     args = parser.parse_args()
 
@@ -117,79 +98,76 @@ def main():
     max_workers = args.max_workers
     model_path = args.model_path
     env_server_base = args.env_server_base
-    max_rounds = args.max_rounds # Use parsed max_rounds
-    output_filename = f"{model_name}_{sector}.jsonl" # Added _conversations to filename
+    max_rounds = args.max_rounds
+    output_filename = args.output_file if args.output_file else f"{model_name}_{sector}.jsonl"
 
     # --- Concurrency Logic ---
-    all_conversations = [] # Store only the conversations
+    all_experiences = []  # Store all experience data, not just conversations
 
     print(f"Starting concurrent evaluation of the first {num_tasks_to_process} tasks with {max_workers} workers.")
     print(f"Each worker will initialize its own evaluator.")
-    print(f"Results (conversations only) will be saved to: {output_filename}")
+    print(f"Results will be saved to: {output_filename}")
     print(f"Model path: {model_path}")
     print(f"Env server base: {env_server_base}")
     print(f"Max rounds per task: {max_rounds}")
 
+    # Track success metrics
+    total_score = 0.0
+    total_success = 0.0
+    total_completed = 0
 
     # Use ThreadPoolExecutor for concurrency
-    # Consider ProcessPoolExecutor if thread-safety issues arise with underlying libraries
-    # or true process isolation is needed (beware of serialization overhead).
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         # Submit tasks to the executor
         future_to_idx = {
-            # Pass necessary arguments for worker initialization
             executor.submit(evaluate_single_task, model_path, env_server_base, max_rounds, i): i
             for i in range(num_tasks_to_process)
         }
 
         # Process results as they complete
-        # Using tqdm to show progress
         for future in tqdm(concurrent.futures.as_completed(future_to_idx), total=len(future_to_idx), desc="Evaluating tasks"):
             idx = future_to_idx[future]
             try:
-                conversation = future.result() # This should be the conversation list or None
-                if conversation is not None:
-                    # Append the conversation directly (no need for the full experience object)
-                    all_conversations.append(conversation)
-                    # Optional: Add task_id if needed for context, though not strictly requested
-                    # all_conversations.append({"task_id": idx, "conversation": conversation})
+                experience_data = future.result()  # This should be the dictionary with conversation, reward, success
+                if experience_data is not None and experience_data["conversation"] is not None:
+                    # Add task_id to experience data
+                    experience_data["item_id"] = f"webshop_{idx}"
+                    all_experiences.append(experience_data)
+
+                    # Update metrics
+                    total_score += experience_data["reward"]
+                    total_success += experience_data["success"]
+                    total_completed += 1
                 else:
-                    # Task failed or returned no conversation, already logged in the function
-                    print(f"Task {idx} completed but returned no conversation data.")
+                    print(f"Task {idx} completed but returned no valid data.")
             except Exception as exc:
-                # This catches errors during future.result() itself, though evaluate_single_task has internal try-except
                 print(f'Task {idx} generated an exception during future processing: {exc}')
                 traceback.print_exc()
 
-
-    print(f"\n==== CONCURRENT EVALUATION COMPLETE (Collected {len(all_conversations)} conversations) ====\n")
+    print(f"\n==== CONCURRENT EVALUATION COMPLETE (Collected {len(all_experiences)} experiences) ====\n")
 
     # --- Save Results to JSONL ---
-    if all_conversations:
-        print(f"Saving {len(all_conversations)} conversations to {output_filename}")
+    if all_experiences:
+        print(f"Saving {len(all_experiences)} experiences to {output_filename}")
         try:
             with open(output_filename, 'w') as f:
-                for i, conv in enumerate(all_conversations):
-                    # Create a dictionary containing only the conversation for each line
-                    # Adding an index might be helpful for reference, but not strictly required
-                    line_data = {
-                        # "original_task_index": i, # Example if you want to track original submission order index
-                        "conversation": conv
-                    }
-                    f.write(json.dumps(line_data) + '\n')
-            print(f"Successfully saved conversations to {output_filename}")
+                for exp in all_experiences:
+                    f.write(json.dumps(exp) + '\n')
+            print(f"Successfully saved experiences to {output_filename}")
         except Exception as e:
             print(f"Error saving results to {output_filename}: {e}")
             traceback.print_exc()
     else:
-        print("No conversations were collected to save.")
-
-    # Example: Print summary
-    total_tasks_attempted = num_tasks_to_process
-    total_conversations_collected = len(all_conversations)
-    print(f"\nSuccessfully collected conversations for {total_conversations_collected} tasks out of {total_tasks_attempted} attempted.")
-
-    # No need to print example conversation as per requirements
+        print("No experiences were collected to save.")
+
+    # Calculate and print evaluation metrics
+    if total_completed > 0:
+        score_average = total_score / total_completed
+        success_rate = total_success / total_completed
+        print("\n\n==== EVALUATION ====\n")
+        print(f"Score: {score_average:.4f}")
+        print(f"Success: {success_rate:.4f}")
+        print(f"Completed Tasks: {total_completed}/{num_tasks_to_process}")
 
 
 if __name__ == "__main__":
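
For reference, a minimal invocation of the patched script might look like the following; the output path is a hypothetical example, and the remaining values are simply the script's own defaults:

    python openmanus_rl/evaluation/vllm_eval_webshop.py \
        --model_name Qwen3-8B \
        --sector eval \
        --num_tasks 100 \
        --max_workers 20 \
        --output_file webshop_eval_results.jsonl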

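Each output line is now a full experience record (conversation, reward, success, item_id) rather than a bare conversation, so the Score and Success numbers can be recomputed offline from the saved file. A minimal sketch, assuming the hypothetical file name from the invocation above:

    import json

    # Recompute the aggregate metrics from a saved results file.
    # The path is the hypothetical one used above; the 'reward' and
    # 'success' fields are exactly what the patched script writes.
    with open("webshop_eval_results.jsonl") as f:
        records = [json.loads(line) for line in f]

    if records:
        n = len(records)
        print(f"Score: {sum(r['reward'] for r in records) / n:.4f}")
        print(f"Success: {sum(r['success'] for r in records) / n:.4f}")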