Skip to content

Commit 139e55a

Browse files
committed
refactor: update chat model arguments and enable relaunch in experiment script; add rate limit testing functionality
1 parent 6aae3cf commit 139e55a

File tree

2 files changed

+97
-5
lines changed

2 files changed

+97
-5
lines changed

main_exp_new_models.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,21 @@
1717

1818
logging.getLogger().setLevel(logging.INFO)
1919

20+
# chat_model_args = CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-mini-2025-04-14"]
21+
# chat_model_args = CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-2025-04-14"]
22+
chat_model_args = CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.7-sonnet"]
2023
agent_args = [
2124
GenericAgentArgs(
22-
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-mini-2025-04-14"],
25+
chat_model_args=chat_model_args,
2326
flags=FLAGS_GPT_4o,
2427
)
2528
]
2629

2730

2831
# ## select the benchmark to run on
29-
benchmark = "miniwob_tiny_test"
32+
# benchmark = "miniwob_tiny_test"
3033
# benchmark = "miniwob"
31-
# benchmark = "workarena_l1"
34+
benchmark = "workarena_l1"
3235
# benchmark = "workarena_l2"
3336
# benchmark = "workarena_l3"
3437
# benchmark = "webarena"
@@ -40,10 +43,10 @@
4043

4144
# Set relaunch = True to relaunch an existing study, this will continue incomplete
4245
# experiments and relaunch errored experiments
43-
relaunch = False
46+
relaunch = True
4447

4548
## Number of parallel jobs
46-
n_jobs = 4 # Make sure to use 1 job when debugging in VSCode
49+
n_jobs = 5 # Make sure to use 1 job when debugging in VSCode
4750
# n_jobs = -1 # to use all available cores
4851

4952

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import time
3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
5+
import anthropic
6+
7+
# Module-level Anthropic client shared by all threads; requires the
# ANTHROPIC_API_KEY environment variable (raises KeyError at import if unset).
client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
8+
9+
10+
def make_request(messages, model="claude-3-5-sonnet-20241022", max_tokens=10):
    """Send *messages* to the Anthropic Messages API and return the usage stats.

    Parameters
    ----------
    messages : list[dict]
        Conversation in Anthropic's messages format.
    model : str
        Model identifier; defaults to the previously hard-coded value so
        existing callers are unaffected.
    max_tokens : int
        Completion cap; kept tiny by default because only `usage` (token
        accounting / cache hits) matters for rate-limit probing, not the text.

    Returns
    -------
    The `usage` object from the API response.
    """
    response = client.messages.create(
        model=model, max_tokens=max_tokens, messages=messages
    )
    return response.usage
15+
16+
17+
def make_message(text):
    """Wrap *text* as a user-role message with a single text content block."""
    text_block = {"type": "text", "text": text}
    return {"role": "user", "content": [text_block]}
27+
28+
29+
def add_cache_control(message: dict, cache_type="ephemeral"):
    """Mark the first content block of *message* for prompt caching (in place)."""
    first_block = message["content"][0]
    first_block["cache_control"] = {"type": cache_type}
31+
32+
33+
def remove_cache_control(message: dict):
    """Remove the prompt-caching marker from the first content block, if present.

    No-op when the block carries no "cache_control" key.
    """
    # pop() with a default replaces the check-then-delete (LBYL) pattern:
    # one lookup instead of two, and safe when the key is absent.
    message["content"][0].pop("cache_control", None)
36+
37+
38+
def test_rate_limit_single(thread_id):
    """Probe API rate limits from one worker by replaying a growing conversation.

    Builds a conversation web-agent style: one large (~100k token) opening
    message followed by medium (~10k token) turns, re-sending the whole list
    each turn with a cache marker on the newest message. Prints elapsed time
    and usage per request; stops on the first API error.
    """
    # Repeated filler text; multipliers give roughly 100k and 10k tokens.
    big_text = "This is a large block of text for caching. " * 10000  # ~100k tokens
    medium_text = "This is a large block of text for caching. " * 2000  # ~10k tokens

    print(f"Thread {thread_id}: Starting rate limit test with cached content...")

    # Conversation grows across turns (simulating a web agent's history).
    messages = []

    for turn in range(5):
        if turn == 0:
            messages.append(make_message(big_text))
            # Clock starts with the first turn, so dt below is cumulative
            # elapsed time since the test began, not per-request latency.
            t0 = time.time()
        else:
            messages.append(make_message(medium_text))
        # Only the newest message carries the cache marker for this request.
        add_cache_control(messages[-1])
        try:
            usage = make_request(messages)
            dt = time.time() - t0
            print(f"{dt:.2f}: Thread {thread_id}: {usage}")
        except Exception as e:
            # Most likely a rate-limit error — that is the signal we probe for.
            print(f"Thread {thread_id}: Error - {e}")
            break
        # Drop the marker so the next turn caches only its own newest message.
        remove_cache_control(messages[-1])
64+
65+
66+
def test_rate_limit_parallel(num_threads=3):
    """Run the single-thread probe on *num_threads* workers concurrently.

    Each worker handles its own API errors; this only reports unexpected
    failures that escape a worker.
    """
    print(f"Starting parallel rate limit test with {num_threads} threads...")

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = [
            pool.submit(test_rate_limit_single, worker_id)
            for worker_id in range(num_threads)
        ]

        for finished in as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                print(f"Thread completed with error: {e}")
77+
78+
79+
def test_rate_limit():
    """Run the rate-limit probe once on the calling thread (thread id 0)."""
    # Original single-threaded version, kept for manual debugging.
    test_rate_limit_single(0)
82+
83+
84+
if __name__ == "__main__":
    # Use parallel version to quickly exhaust rate limits
    # (3 concurrent workers, each replaying a large cached conversation).
    test_rate_limit_parallel(num_threads=3)

    # Or use original single-threaded version
    # test_rate_limit()

0 commit comments

Comments
 (0)