+def test_benchmark(infer_backend):
+    import os
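+    # Pin the run to a single GPU; TIMEOUT=-1 presumably disables the client-side request timeout.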
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+    os.environ['TIMEOUT'] = '-1'
+    import requests
+    from swift.llm import DeployArguments, get_dataset, get_model_list_client, XRequestConfig, inference_client_async
+    from swift.llm.deploy import llm_deploy
+    import multiprocessing
+    import time
+    import asyncio
+    from swift.utils import get_logger
+
+    logger = get_logger()
+
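+    # Launch the deployment server (llm_deploy) in a separate process; 'spawn' avoids CUDA init issues with fork.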
+    mp = multiprocessing.get_context('spawn')
+    process = mp.Process(
+        target=llm_deploy,
+        args=(DeployArguments(model_type='qwen2-7b-instruct', infer_backend=infer_backend, verbose=False), ))
+    process.start()
+
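+    # Benchmark workload: 1000 queries sampled from alpaca-zh plus 1000 from alpaca-en.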
+    dataset = get_dataset(['alpaca-zh#1000', 'alpaca-en#1000'])[0]
+    query_list = dataset['query']
+    request_config = XRequestConfig(seed=42, max_tokens=8192)
+
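+    # Poll until the server is reachable; the deploy process needs time to load the model.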
+    while True:
+        try:
+            model_list = get_model_list_client()
+        except requests.exceptions.ConnectionError:
+            time.sleep(5)
+            continue
+        break
+    model_type = model_list.data[0].id
+    is_chat = model_list.data[0].is_chat
+    is_multimodal = model_list.data[0].is_multimodal
+    print(f'model_type: {model_type}')
+
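+    # Build one asynchronous client request (coroutine) per query.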
+    tasks = []
+    for query in query_list:
+        tasks.append(
+            inference_client_async(
+                model_type, query, request_config=request_config, is_chat=is_chat, is_multimodal=is_multimodal))
+
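+    # Fire all requests concurrently and collect the responses.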
+    async def _batch_run(tasks):
+        return await asyncio.gather(*tasks)
+
+    resp_list = asyncio.run(_batch_run(tasks))
+    logger.info(f'len(resp_list): {len(resp_list)}')
+    logger.info(f'resp_list[0]: {resp_list[0]}')
+    process.terminate()
+
+
+def test_vllm_benchmark():
+    test_benchmark('vllm')
+
+
+def test_lmdeploy_benchmark():
+    test_benchmark('lmdeploy')
+
+
+if __name__ == '__main__':
+    # test_vllm_benchmark()
+    test_lmdeploy_benchmark()