Skip to content

Commit 010ff8c

Browse files
committed
feat: use custom http_client
This commit adds the ability to pass a custom HTTP client to the MT-Bench evaluator. This is handy when using custom certificates while interacting with the judge model serving endpoint. Signed-off-by: Sébastien Han <[email protected]>
1 parent cd0487e commit 010ff8c

File tree

8 files changed

+41
-5
lines changed

8 files changed

+41
-5
lines changed

.spellcheck-en-custom.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dr
1010
eval
1111
gpt
1212
hoc
13+
http
1314
instructlab
1415
jsonl
1516
justfile

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
## v0.4
2+
3+
* Added ability to specify a custom http client to MT-Bench
4+
15
## v0.2

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ accelerate
99
pandas
1010
pandas-stubs
1111
lm-eval>=0.4.4
12+
httpx

src/instructlab/eval/mt_bench.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
import multiprocessing
1111
import os
1212

13+
# Third Party
14+
import httpx
15+
1316
# First Party
1417
from instructlab.eval import (
1518
mt_bench_answers,
@@ -110,6 +113,7 @@ def gen_answers(
110113
api_key: str | None = None,
111114
max_workers: int | str | None = None,
112115
serving_gpus: int | None = None,
116+
http_client: httpx.Client | None = None,
113117
) -> None:
114118
"""
115119
Asks questions to model
@@ -119,6 +123,7 @@ def gen_answers(
119123
api_key API token for authenticating with model server
120124
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
121125
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
126+
http_client Custom http client to use for requests
122127
"""
123128
logger.debug(locals())
124129
mt_bench_answers.generate_answers(
@@ -127,6 +132,7 @@ def gen_answers(
127132
api_key=api_key,
128133
output_dir=self.output_dir,
129134
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
135+
http_client=http_client,
130136
)
131137

132138
def judge_answers(
@@ -135,6 +141,7 @@ def judge_answers(
135141
api_key: str | None = None,
136142
max_workers: int | str | None = None,
137143
serving_gpus: int | None = None,
144+
http_client: httpx.Client | None = None,
138145
) -> tuple:
139146
"""
140147
Runs MT-Bench judgment
@@ -144,6 +151,7 @@ def judge_answers(
144151
api_key API token for authenticating with model server
145152
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
146153
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
154+
http_client Custom http client to use for requests
147155
148156
Returns:
149157
overall_score MT-Bench score for the overall model evaluation
@@ -160,6 +168,7 @@ def judge_answers(
160168
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
161169
output_dir=self.output_dir,
162170
merge_system_user_message=self.merge_system_user_message,
171+
http_client=http_client,
163172
)
164173

165174

@@ -202,6 +211,7 @@ def gen_answers(
202211
api_key: str | None = None,
203212
max_workers: int | str | None = None,
204213
serving_gpus: int | None = None,
214+
http_client: httpx.Client | None = None,
205215
) -> None:
206216
"""
207217
Asks questions to model
@@ -211,6 +221,7 @@ def gen_answers(
211221
api_key API token for authenticating with model server
212222
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
213223
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
224+
http_client Custom http client to use for requests
214225
"""
215226
logger.debug(locals())
216227
mt_bench_branch_generator.generate(
@@ -228,6 +239,7 @@ def gen_answers(
228239
data_dir=self.output_dir,
229240
max_workers=self._get_effective_max_workers(max_workers, serving_gpus),
230241
bench_name="mt_bench_branch",
242+
http_client=http_client,
231243
)
232244

233245
def judge_answers(
@@ -236,6 +248,7 @@ def judge_answers(
236248
api_key: str | None = None,
237249
max_workers: int | str | None = None,
238250
serving_gpus: int | None = None,
251+
http_client: httpx.Client | None = None,
239252
) -> tuple:
240253
"""
241254
Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name.
@@ -245,6 +258,7 @@ def judge_answers(
245258
api_key API token for authenticating with model server
246259
max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor.
247260
serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor.
261+
http_client Custom http client to use for requests
248262
249263
Returns:
250264
overall_score Overall score from the evaluation
@@ -263,5 +277,6 @@ def judge_answers(
263277
data_dir=self.output_dir,
264278
bench_name="mt_bench_branch",
265279
merge_system_user_message=self.merge_system_user_message,
280+
http_client=http_client,
266281
)
267282
return overall_score, qa_pairs, error_rate

src/instructlab/eval/mt_bench_answers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,12 @@ def generate_answers(
108108
max_tokens=1024,
109109
max_workers=1,
110110
bench_name="mt_bench",
111+
http_client=None,
111112
):
112113
"""Generate model answers to be judged"""
113114
logger.debug(locals())
114115

115-
openai_client = get_openai_client(model_api_base, api_key)
116+
openai_client = get_openai_client(model_api_base, api_key, http_client)
116117

117118
if data_dir is None:
118119
data_dir = os.path.join(os.path.dirname(__file__), "data")

src/instructlab/eval/mt_bench_common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import time
1414

1515
# Third Party
16+
import httpx
1617
import openai
1718

1819
# First Party
@@ -365,8 +366,14 @@ def get_model_list(answer_file):
365366
return [os.path.splitext(os.path.basename(answer_file))[0]]
366367

367368

368-
def get_openai_client(model_api_base, api_key):
369+
def get_openai_client(
370+
model_api_base,
371+
api_key,
372+
http_client: httpx.Client | None = None,
373+
):
369374
if api_key is None:
370375
api_key = "NO_API_KEY"
371-
openai_client = openai.OpenAI(base_url=model_api_base, api_key=api_key)
376+
openai_client = openai.OpenAI(
377+
base_url=model_api_base, api_key=api_key, http_client=http_client
378+
)
372379
return openai_client

src/instructlab/eval/mt_bench_judgment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,12 @@ def generate_judgment(
286286
max_workers=1,
287287
first_n=None,
288288
merge_system_user_message=False,
289+
http_client=None,
289290
):
290291
"""Generate judgment with scores and qa_pairs for a model"""
291292
logger.debug(locals())
292293

293-
openai_client = get_openai_client(model_api_base, api_key)
294+
openai_client = get_openai_client(model_api_base, api_key, http_client)
294295

295296
first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS")
296297
if first_n_env is not None and first_n is None:

tests/test_branch_gen_answers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Third Party
2+
import httpx
3+
14
# First Party
25
from instructlab.eval.mt_bench import MTBenchBranchEvaluator
36

@@ -7,4 +10,7 @@
710
"../taxonomy",
811
"main",
912
)
10-
mt_bench_branch.gen_answers("http://localhost:8000/v1")
13+
mt_bench_branch.gen_answers(
14+
"http://localhost:8000/v1",
15+
http_client=httpx.Client(verify=False),
16+
)

0 commit comments

Comments (0)