Skip to content

Commit b54e2fd

Browse files
committed
[Refactor] Support concurrent inference across tasks.
1 parent 9741792 commit b54e2fd

File tree

13 files changed

+1182
-27
lines changed

13 files changed

+1182
-27
lines changed

opencompass/cli/main.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,36 @@
55
import getpass
66
import os
77
import os.path as osp
8+
import threading
89
from datetime import datetime
910

1011
from mmengine.config import Config, DictAction
1112

1213
from opencompass.registry import PARTITIONERS, RUNNERS, build_from_cfg
1314
from opencompass.runners import SlurmRunner
1415
from opencompass.summarizers import DefaultSummarizer
15-
from opencompass.utils import (LarkReporter, get_logger, pretty_print_config,
16-
read_from_station, save_to_station)
16+
from opencompass.utils import (HeartBeatManager, LarkReporter, get_logger,
17+
pretty_print_config, read_from_station,
18+
save_to_station)
1719
from opencompass.utils.run import (fill_eval_cfg, fill_infer_cfg,
1820
get_config_from_arg)
1921

2022

23+
def _run_eval_tasks(runner, tasks):
24+
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(tasks[0],
25+
list):
26+
for task_part in tasks:
27+
runner(task_part)
28+
else:
29+
runner(tasks)
30+
31+
32+
def _is_eval_daemon(task_type) -> bool:
33+
if isinstance(task_type, str):
34+
return task_type.endswith('OpenICLEvalWatchTask')
35+
return getattr(task_type, '__name__', '') == 'OpenICLEvalWatchTask'
36+
37+
2138
def parse_args():
2239
parser = argparse.ArgumentParser(description='Run an evaluation task')
2340
parser.add_argument('config', nargs='?', help='Train config file path')
@@ -318,7 +335,15 @@ def main():
318335
if args.config_verbose:
319336
pretty_print_config(cfg)
320337

321-
# infer
338+
infer_tasks = None
339+
infer_runner = None
340+
eval_tasks = None
341+
eval_runner = None
342+
eval_daemon = False
343+
344+
# ========================
345+
# Setup Configuration
346+
# ========================
322347
if args.mode in ['all', 'infer']:
323348
# When user have specified --slurm or --dlc, or have not set
324349
# "infer" in config, we will provide a default configuration
@@ -358,7 +383,8 @@ def main():
358383
if args.dump_res_length:
359384
for task in tasks:
360385
task.dump_res_length = True
361-
runner(tasks)
386+
infer_tasks = tasks
387+
infer_runner = runner
362388

363389
# evaluate
364390
if args.mode in ['all', 'eval']:
@@ -397,14 +423,35 @@ def main():
397423
if args.dry_run:
398424
return
399425
runner = RUNNERS.build(cfg.eval.runner)
400-
401-
# For meta-review-judge in subjective evaluation
402-
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
403-
tasks[0], list):
404-
for task_part in tasks:
405-
runner(task_part)
406-
else:
407-
runner(tasks)
426+
task_type = getattr(cfg.eval.runner, 'task', {}).get('type', '')
427+
eval_daemon = _is_eval_daemon(task_type)
428+
429+
eval_tasks = tasks
430+
eval_runner = runner
431+
432+
# =================
433+
# Startup Runner
434+
# =================
435+
if infer_runner and eval_runner and eval_daemon:
436+
heartbeat = HeartBeatManager(cfg['work_dir'])
437+
stop_event, hb_thread = heartbeat.start_heartbeat()
438+
439+
eval_thread = threading.Thread(target=_run_eval_tasks,
440+
args=(eval_runner, eval_tasks),
441+
daemon=True)
442+
eval_thread.start()
443+
444+
infer_runner(infer_tasks)
445+
446+
stop_event.set()
447+
hb_thread.join()
448+
logger.info('All infer tasks finished, stop heartbeat.')
449+
eval_thread.join()
450+
else:
451+
if infer_runner is not None:
452+
infer_runner(infer_tasks)
453+
if eval_runner is not None:
454+
_run_eval_tasks(eval_runner, eval_tasks)
408455

409456
# save to station
410457
if args.station_path is not None or cfg.get('station_path') is not None:

opencompass/models/openai_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def generate(
185185
if self.temperature is not None:
186186
temperature = self.temperature
187187

188+
if len(inputs) == 1:
189+
# Skip multi-threading for a single inference.
190+
return [self._generate(inputs[0], max_out_len, temperature)]
191+
188192
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
189193
results = list(
190194
tqdm(
@@ -254,6 +258,7 @@ def _generate(self, input: PromptType, max_out_len: int,
254258
self.org_ctr = 0
255259
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
256260

261+
self.acquire()
257262
try:
258263
if any(model in self.path
259264
for model in OAI_REASONING_MODEL_LIST):
@@ -377,6 +382,8 @@ def _generate(self, input: PromptType, max_out_len: int,
377382
'Find error message in response: ',
378383
str(response['error']),
379384
)
385+
finally:
386+
self.release()
380387
max_num_retries += 1
381388

382389
raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -697,6 +704,7 @@ def _generate(
697704
if self.openai_extra_kwargs:
698705
query_data.update(self.openai_extra_kwargs)
699706

707+
self.acquire()
700708
try:
701709
if self.verbose:
702710
self.logger.info('Start calling OpenAI API')
@@ -789,6 +797,8 @@ def _generate(
789797
except Exception as e:
790798
self.logger.error(f'error occurs at {self.openai_api_base}')
791799
self.logger.error(e)
800+
finally:
801+
self.release()
792802
num_retries += 1
793803
raise RuntimeError('Calling OpenAI API failed after retrying for '
794804
f'{self.retry} times. Check the logs for details.')
@@ -925,6 +935,7 @@ def _generate(
925935
if self.openai_extra_kwargs:
926936
query_data.update(self.openai_extra_kwargs)
927937

938+
self.acquire()
928939
try:
929940
if self.verbose:
930941
self.logger.info('Start calling OpenAI API')
@@ -1052,6 +1063,8 @@ def _generate(
10521063
except Exception as e:
10531064
self.logger.error(f'error occurs at {self.openai_api_base}')
10541065
self.logger.error(e)
1066+
finally:
1067+
self.release()
10551068
num_retries += 1
10561069
raise RuntimeError('Calling OpenAI API failed after retrying for '
10571070
f'{self.retry} times. Check the logs for details.')

opencompass/models/openai_streaming.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(self,
8686
self.openai_extra_kwargs = openai_extra_kwargs
8787
self.timeout = timeout
8888
self.finish_reason_confirm = finish_reason_confirm
89+
self.openai_client = self._create_fresh_client()
8990

9091
def _create_fresh_client(self):
9192
"""Create a fresh OpenAI client for each request to avoid
@@ -117,11 +118,15 @@ def _create_fresh_client(self):
117118
'https://': self.proxy_url,
118119
}
119120

121+
limits = httpx.Limits(max_keepalive_connections=2048,
122+
max_connections=4096)
123+
120124
return OpenAI(
121125
base_url=self.openai_api_base,
122126
api_key=current_key,
123127
http_client=httpx.Client(**http_client_cfg,
124-
timeout=httpx.Timeout(self.timeout))
128+
timeout=httpx.Timeout(self.timeout),
129+
limits=limits)
125130
if http_client_cfg or True else None,
126131
)
127132

@@ -185,6 +190,7 @@ def _generate(
185190
if self.openai_extra_kwargs:
186191
query_data.update(self.openai_extra_kwargs)
187192

193+
self.acquire()
188194
try:
189195
if self.verbose:
190196
thread_id = threading.get_ident()
@@ -193,22 +199,13 @@ def _generate(
193199
f'with streaming enabled')
194200

195201
if self.stream:
196-
# Create fresh client for each request to avoid
197-
# concurrency issues
198-
fresh_client = self._create_fresh_client()
199-
200202
# Handle streaming response with shorter timeout
201-
response_stream = fresh_client.chat.completions.create(
203+
response_stream = self.openai_client.chat.completions.create(
202204
**query_data, timeout=self.timeout)
203205

204206
result = self._handle_stream_response(
205207
response_stream, thread_id if self.verbose else None)
206208

207-
# Clean up the client
208-
if (hasattr(fresh_client, '_client')
209-
and hasattr(fresh_client._client, 'close')):
210-
fresh_client._client.close()
211-
212209
return result
213210
else:
214211
# Fallback to non-streaming (use parent method)
@@ -237,6 +234,8 @@ def _generate(
237234
import traceback
238235
self.logger.error(f'[Thread {thread_id}] Traceback: '
239236
f'{traceback.format_exc()}')
237+
finally:
238+
self.release()
240239
num_retries += 1
241240

242241
raise RuntimeError('Calling OpenAI API failed after retrying for '

opencompass/openicl/icl_inferencer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from .icl_attack_inferencer import AttackInferencer # noqa
33
from .icl_base_inferencer import BaseInferencer # noqa
44
from .icl_chat_inferencer import ChatInferencer # noqa
5+
from .icl_chat_inferencer_parallel import ParallelChatInferencer # noqa
56
from .icl_chatml_inferencer import ChatMLInferencer # noqa
67
from .icl_clp_inferencer import CLPInferencer # noqa
78
from .icl_gen_inferencer import GenInferencer # noqa
9+
from .icl_gen_inferencer_parallel import ParallelGenInferencer # noqa
810
from .icl_inference_ppl_only_inferencer import \
911
InferencePPLOnlyInferencer # noqa
1012
from .icl_ll_inferencer import LLInferencer # noqa
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Parallel Chat Inferencer."""
2+
import os
3+
import os.path as osp
4+
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
from typing import List, Optional
6+
7+
import mmengine
8+
9+
from opencompass.registry import ICL_INFERENCERS
10+
11+
from ..icl_prompt_template import PromptTemplate
12+
from ..icl_retriever import BaseRetriever
13+
from ..utils.logging import get_logger
14+
from .icl_chat_inferencer import ChatInferencer
15+
16+
logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class ParallelChatInferencer(ChatInferencer):
    """Chat inferencer that fans individual samples out to a thread pool.

    Behaves like :class:`ChatInferencer`, but runs each sample in a
    worker thread so that I/O-bound (API-backed) models can serve many
    requests concurrently. Partial results are checkpointed to a
    ``tmp_``-prefixed JSON file so an interrupted run can be resumed.
    """

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            max_out_len: int = 512,
            max_infer_workers: Optional[int] = None,
            **kwargs) -> None:
        """Initialize the parallel chat inferencer.

        Args:
            model: Model (or model config) used for inference.
            output_json_filepath (str): Directory of the output JSON file.
            output_json_filename (str): Name of the output JSON file.
            save_every (int): Checkpoint partial results roughly every
                this many completed result entries. ``None`` disables
                intermediate checkpointing.
            infer_mode (str): ``'last'``, ``'every'`` or
                ``'every_with_gt'`` (see ``ChatInferencer``).
            max_out_len (int): Maximum generated length per sample.
            max_infer_workers (int, optional): Thread-pool size. When
                ``None``, falls back to the model's ``max_workers``
                attribute, then to ``min(32, os.cpu_count() + 4)``.
        """
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            infer_mode=infer_mode,
            max_out_len=max_out_len,
            **kwargs,
        )
        self.max_infer_workers = max_infer_workers
        # Optional progress sink; presumably set externally by the owning
        # task/runner (e.g. a watch task) — TODO confirm against callers.
        self.progress_tracker = None

    def _resolve_max_workers(self) -> int:
        """Pick the pool size: explicit arg > model hint > CPU heuristic."""
        if self.max_infer_workers is not None:
            return self.max_infer_workers
        model_workers = getattr(self.model, 'max_workers', None)
        if model_workers is not None:
            return model_workers
        # Same heuristic as ThreadPoolExecutor's own default.
        return min(32, (os.cpu_count() or 1) + 4)

    def _progress_update(self, count: int = 1) -> None:
        """Report ``count`` newly finished result entries to the tracker."""
        if self.progress_tracker is not None:
            self.progress_tracker.incr(count)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> dict:
        """Run chat inference over the retrieved samples in parallel.

        Resumes from a ``tmp_``-prefixed checkpoint when one exists,
        executes the remaining samples in a thread pool, checkpoints
        periodically, and finally writes the merged results.

        Args:
            retriever: Retriever providing the in-context example indices.
            ice_template: Unused here; kept for interface compatibility.
            prompt_template: Optional template applied to each prompt.
            output_json_filepath: Override for the output directory.
            output_json_filename: Override for the output file name.

        Returns:
            dict: Mapping of sample index (as ``str``) to its result entry.
        """
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        ice_idx_list = retriever.retrieve()

        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        total_samples = len(chat_list)
        if self.progress_tracker is not None:
            self.progress_tracker.set_total(total_samples)

        # Resume support: drop indices already present in the checkpoint.
        todo = list(range(total_samples))
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                # Corrupt or partial checkpoint: warn (instead of silently
                # ignoring) and redo everything from scratch.
                logger.warning('Failed to load checkpoint %s; restarting.',
                               tmp_json_filepath)
            else:
                output_handler.results_dict = tmp_result_dict
                todo = [i for i in todo if str(i) not in tmp_result_dict]
                if self.progress_tracker is not None:
                    self.progress_tracker.set_completed(total_samples -
                                                       len(todo))

        chats = [chat_list[i] for i in todo]

        logger.info('Starting parallel chat inference process...')

        def _infer_one(chat, idx):
            # Each worker writes into a private handler so threads never
            # mutate a shared results dict concurrently; results are
            # merged in the main thread below.
            local_handler = self.HandlerType()
            if self.infer_mode == 'last':
                self.infer_last(chat, idx, local_handler)
            elif self.infer_mode == 'every':
                self.infer_every(chat, idx, local_handler)
            elif self.infer_mode == 'every_with_gt':
                self.infer_every_with_gt(chat, idx, local_handler)
            return local_handler.results_dict

        max_workers = self._resolve_max_workers()
        completed = total_samples - len(todo)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(_infer_one, chat, idx)
                for idx, chat in zip(todo, chats)
            ]
            for future in as_completed(futures):
                result_dict = future.result()
                output_handler.results_dict.update(result_dict)
                delta = len(result_dict)
                completed += delta
                self._progress_update(delta)
                # Checkpoint whenever a ``save_every`` boundary has been
                # crossed. ``delta`` can exceed 1 (infer_mode='every'),
                # so an exact ``completed % save_every == 0`` test could
                # jump past a multiple and skip checkpoints.
                if (self.save_every is not None and self.is_main_process
                        and completed // self.save_every
                        > (completed - delta) // self.save_every):
                    output_handler.write_to_json(
                        output_json_filepath, 'tmp_' + output_json_filename)

        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict

0 commit comments

Comments
 (0)