Skip to content

Commit 6eda33a

Browse files
raise_for_status
1 parent 99336fa commit 6eda33a

File tree

1 file changed

+4
-21
lines changed

1 file changed

+4
-21
lines changed

scripts/server-bench.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
9292
f"{server_address}/apply-template",
9393
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
9494
)
95-
if response.status_code != 200:
96-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
95+
response.raise_for_status()
9796
prompt: str = json.loads(response.text)["prompt"]
9897
response = session.post(
9998
f"{server_address}/tokenize",
10099
json={"content": prompt, "add_special": True}
101100
)
102-
if response.status_code != 200:
103-
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
101+
response.raise_for_status()
104102
tokens: list[str] = json.loads(response.text)["tokens"]
105103
return len(tokens)
106104

@@ -125,18 +123,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
125123
f"{server_address}/apply-template",
126124
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
127125
)
128-
if response.status_code != 200:
129-
response_text = ""
130-
try:
131-
response_text = response.text
132-
response_text = ": {response_text}"
133-
except RuntimeError:
134-
pass
135-
raise RuntimeError(f"Server returned status code {response.status_code}{response_text}")
126+
response.raise_for_status()
136127
prompt: str = json.loads(response.text)["prompt"]
137128

138129
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
139130
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
131+
response.raise_for_status()
140132

141133
lines = []
142134
token_arrival_times: list[float] = []
@@ -149,15 +141,6 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
149141
if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
150142
token_arrival_times = token_arrival_times[:-1]
151143

152-
if response.status_code != 200:
153-
response_text = ""
154-
try:
155-
response_text = response.text
156-
response_text = ": {response_text}"
157-
except RuntimeError:
158-
pass
159-
raise RuntimeError(f"Server returned status code {response.status_code}{response_text}")
160-
161144
return (t_submit, token_arrival_times)
162145

163146

0 commit comments

Comments
 (0)