diff --git a/mii/client.py b/mii/client.py
index b914a1a7..3cb47589 100644
--- a/mii/client.py
+++ b/mii/client.py
@@ -10,6 +10,7 @@
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
 from mii.constants import GRPC_MAX_MSG_SIZE
 from mii.method_table import GRPC_METHOD_TABLE
+from mii.event_loop import get_event_loop
 
 
 def _get_deployment_info(deployment_name):
@@ -57,7 +58,7 @@ class MIIClient():
     Client to send queries to a single endpoint.
     """
     def __init__(self, task_name, host, port):
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
         self.task = get_task(task_name)
@@ -74,17 +75,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
             proto_response
         ) if "unpack_response_from_proto" in conversions else proto_response
 
-    def query(self, request_dict, **query_kwargs):
-        return self.asyncio_loop.run_until_complete(
+    def query_async(self, request_dict, **query_kwargs):
+        return asyncio.run_coroutine_threadsafe(
             self._request_async_response(request_dict,
-                                         **query_kwargs))
+                                         **query_kwargs),
+            get_event_loop())
+
+    def query(self, request_dict, **query_kwargs):
+        return self.query_async(request_dict, **query_kwargs).result()
 
     async def terminate_async(self):
         await self.stub.Terminate(
             modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
 
     def terminate(self):
-        self.asyncio_loop.run_until_complete(self.terminate_async())
+        asyncio.run_coroutine_threadsafe(self.terminate_async(),
+                                         get_event_loop()).result()
 
 
 class MIITensorParallelClient():
@@ -95,7 +101,7 @@ class MIITensorParallelClient():
     def __init__(self, task_name, host, ports):
         self.task = get_task(task_name)
         self.clients = [MIIClient(task_name, host, port) for port in ports]
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
 
     # runs task in parallel and return the result from the first task
     async def _query_in_tensor_parallel(self, request_string, query_kwargs):
@@ -107,7 +113,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
                     **query_kwargs)))
 
         await responses[0]
-        return responses[0]
+        return responses[0].result()
+
+    def query_async(self, request_dict, **query_kwargs):
+        """Asynchronously query a local deployment.
+        See `query` for the arguments and the return value.
+        """
+        return asyncio.run_coroutine_threadsafe(
+            self._query_in_tensor_parallel(request_dict,
+                                           query_kwargs),
+            self.asyncio_loop)
 
     def query(self, request_dict, **query_kwargs):
         """Query a local deployment:
@@ -122,11 +137,7 @@ def query(self, request_dict, **query_kwargs):
         Returns:
             response: Response of the model
         """
-        response = self.asyncio_loop.run_until_complete(
-            self._query_in_tensor_parallel(request_dict,
-                                           query_kwargs))
-        ret = response.result()
-        return ret
+        return self.query_async(request_dict, **query_kwargs).result()
 
     def terminate(self):
         """Terminates the deployment"""
@@ -136,5 +147,5 @@ def terminate(self):
 
 def terminate_restful_gateway(deployment_name):
     _, mii_configs = _get_deployment_info(deployment_name)
-    if mii_configs.restful_api_port > 0:
+    if mii_configs.enable_restful_api:
         requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
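The hunks above replace the per-call `run_until_complete` pattern with coroutines submitted to a single shared loop (added in `mii/event_loop.py`, shown next), so `query_async` returns a `concurrent.futures.Future` that any thread can wait on. A minimal usage sketch under the assumption that a deployment is already listening; the task name, host, port, and request payload below are illustrative, not taken from this patch:

```python
# Hypothetical usage of the new query_async API; "text-generation",
# "localhost", and 50050 are placeholders for an existing deployment.
from mii.client import MIIClient

client = MIIClient("text-generation", "localhost", 50050)

# query_async returns a concurrent.futures.Future, so several requests
# can be in flight at once on the shared background event loop.
futures = [client.query_async({"query": p}) for p in ("hello", "world")]

# .result() blocks the calling thread until the response arrives,
# which is exactly what the synchronous query() wrapper now does.
responses = [f.result() for f in futures]
```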
diff --git a/mii/event_loop.py b/mii/event_loop.py
new file mode 100644
index 00000000..4040c86f
--- /dev/null
+++ b/mii/event_loop.py
@@ -0,0 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import asyncio
+import threading
+
+global event_loop
+event_loop = asyncio.get_event_loop()
+threading.Thread(target=event_loop.run_forever, daemon=True).start()
+
+
+def get_event_loop():
+    return event_loop
diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py
index bbbf7dfe..c4e73888 100644
--- a/mii/grpc_related/modelresponse_server.py
+++ b/mii/grpc_related/modelresponse_server.py
@@ -17,6 +17,7 @@
 from mii.method_table import GRPC_METHOD_TABLE
 from mii.client import create_channel
 from mii.utils import get_task
+from mii.event_loop import get_event_loop
 
 
 class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
@@ -42,6 +43,7 @@ def __init__(self, inference_pipeline):
         super().__init__()
         self.inference_pipeline = inference_pipeline
         self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
+        self.lock = threading.Lock()
 
     def _get_model_time(self, model, sum_times=False):
         model_times = []
@@ -72,7 +74,8 @@ def _run_inference(self, method_name, request_proto):
         args, kwargs = conversions["unpack_request_from_proto"](request_proto)
 
         start = time.time()
-        response = self.inference_pipeline(*args, **kwargs)
+        with self.lock:
+            response = self.inference_pipeline(*args, **kwargs)
         end = time.time()
 
         model_time = self._get_model_time(self.inference_pipeline.model,
@@ -134,7 +137,7 @@ def __init__(self, host, ports):
             stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
             self.stubs.append(stub)
 
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
 
     async def _invoke_async(self, method_name, proto_request):
         responses = []
@@ -154,7 +157,7 @@ def invoke(self, method_name, proto_request):
 class LoadBalancingInterceptor(grpc.ServerInterceptor):
     def __init__(self, task_name, replica_configs):
         super().__init__()
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
 
         self.stubs = [
             ParallelStubInvoker(replica.hostname,
@@ -164,13 +167,6 @@ def __init__(self, task_name, replica_configs):
         self.counter = AtomicCounter()
         self.task = get_task(task_name)
 
-        # Start the asyncio loop in a separate thread
-        def run_asyncio_loop(loop):
-            asyncio.set_event_loop(loop)
-            loop.run_forever()
-
-        threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start()
-
     def choose_stub(self, call_count):
         return self.stubs[call_count % len(self.stubs)]
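For context on the new lock in `_run_inference`: with every client and interceptor now sharing one background loop, requests can reach a replica concurrently from multiple gRPC worker threads, and the lock serializes access to the single inference pipeline, which is not assumed to be thread-safe. A minimal sketch of that pattern; every name below is an illustrative stand-in, not an MII API:

```python
# Sketch of the serialization the lock provides: several worker threads
# funnel through one non-thread-safe pipeline, one call at a time.
import threading
import time

class FakePipeline:
    """Stands in for a non-thread-safe inference pipeline."""
    def __call__(self, prompt):
        time.sleep(0.01)  # simulate model work
        return prompt.upper()

pipeline = FakePipeline()
lock = threading.Lock()
results = []

def handle_request(prompt):
    # Mirrors _run_inference: only one thread drives the model at a time.
    with lock:
        results.append(pipeline(prompt))

threads = [threading.Thread(target=handle_request, args=(p,)) for p in "abc"]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # e.g. ['A', 'B', 'C'], order depends on scheduling
```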